添加PhysDock初始代码
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +204 -0
- .idea/.gitignore +8 -0
- .idea/PhysDock.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +24 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- License +21 -0
- PhysDock/__init__.py +3 -0
- PhysDock/configs.py +195 -0
- PhysDock/configs_old.py +245 -0
- PhysDock/data/__init__.py +109 -0
- PhysDock/data/alignment_runner.py +937 -0
- PhysDock/data/alignment_runner_v2.py +327 -0
- PhysDock/data/constants/PDBData.py +348 -0
- PhysDock/data/constants/__init__.py +0 -0
- PhysDock/data/constants/periodic_table.py +27 -0
- PhysDock/data/constants/residue_constants.py +562 -0
- PhysDock/data/constants/restype_constants.py +107 -0
- PhysDock/data/feature_loader.py +1283 -0
- PhysDock/data/feature_loader_plinder.py +1258 -0
- PhysDock/data/generate_system.py +148 -0
- PhysDock/data/relaxation.py +259 -0
- PhysDock/data/tools/PDBData.py +348 -0
- PhysDock/data/tools/__init__.py +0 -0
- PhysDock/data/tools/alignment_runner.py +588 -0
- PhysDock/data/tools/convert_unifold_template_to_stfold.py +127 -0
- PhysDock/data/tools/dataset_manager.py +570 -0
- PhysDock/data/tools/feature_processing_multimer.py +257 -0
- PhysDock/data/tools/get_metrics.py +294 -0
- PhysDock/data/tools/hhblits.py +175 -0
- PhysDock/data/tools/hhsearch.py +126 -0
- PhysDock/data/tools/hmmalign.py +66 -0
- PhysDock/data/tools/hmmbuild.py +165 -0
- PhysDock/data/tools/hmmsearch.py +137 -0
- PhysDock/data/tools/jackhmmer.py +262 -0
- PhysDock/data/tools/kalign.py +114 -0
- PhysDock/data/tools/mmcif_parsing.py +519 -0
- PhysDock/data/tools/msa_identifiers.py +90 -0
- PhysDock/data/tools/msa_pairing.py +496 -0
- PhysDock/data/tools/nhmmer.py +257 -0
- PhysDock/data/tools/parse_msas.py +328 -0
- PhysDock/data/tools/parsers.py +727 -0
- PhysDock/data/tools/rdkit.py +220 -0
- PhysDock/data/tools/residue_constants.py +604 -0
- PhysDock/data/tools/templates.py +1357 -0
- PhysDock/data/tools/utils.py +48 -0
- PhysDock/models/__init__.py +0 -0
- PhysDock/models/layers/__init__.py +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Marimo
|
| 198 |
+
marimo/_static/
|
| 199 |
+
marimo/_lsp/
|
| 200 |
+
__marimo__/
|
| 201 |
+
|
| 202 |
+
# Streamlit
|
| 203 |
+
.streamlit/secrets.toml
|
| 204 |
+
params/params.pt
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/PhysDock.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="PhysDock" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="GOOGLE" />
|
| 10 |
+
<option name="myDocStringFormat" value="Google" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 5 |
+
<option name="ignoredPackages">
|
| 6 |
+
<value>
|
| 7 |
+
<list size="11">
|
| 8 |
+
<item index="0" class="java.lang.String" itemvalue="tqdm" />
|
| 9 |
+
<item index="1" class="java.lang.String" itemvalue="scipy" />
|
| 10 |
+
<item index="2" class="java.lang.String" itemvalue="deepspeed" />
|
| 11 |
+
<item index="3" class="java.lang.String" itemvalue="PyYAML" />
|
| 12 |
+
<item index="4" class="java.lang.String" itemvalue="pytorch_lightning" />
|
| 13 |
+
<item index="5" class="java.lang.String" itemvalue="ml-collections" />
|
| 14 |
+
<item index="6" class="java.lang.String" itemvalue="torch" />
|
| 15 |
+
<item index="7" class="java.lang.String" itemvalue="typing-extensions" />
|
| 16 |
+
<item index="8" class="java.lang.String" itemvalue="numpy" />
|
| 17 |
+
<item index="9" class="java.lang.String" itemvalue="requests" />
|
| 18 |
+
<item index="10" class="java.lang.String" itemvalue="dm-tree" />
|
| 19 |
+
</list>
|
| 20 |
+
</value>
|
| 21 |
+
</option>
|
| 22 |
+
</inspection_tool>
|
| 23 |
+
</profile>
|
| 24 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="PhysDock" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="PhysDock" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/PhysDock.iml" filepath="$PROJECT_DIR$/.idea/PhysDock.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
License
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 ShanghaiTech University
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
PhysDock/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PhysDock.models.model import PhysDock
|
| 2 |
+
from PhysDock.models.loss import PhysDockLoss
|
| 3 |
+
from PhysDock.configs import PhysDockConfig
|
PhysDock/configs.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_collections as mlc
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def PhysDockConfig(
|
| 5 |
+
inference_mode=True,
|
| 6 |
+
model_name="medium",
|
| 7 |
+
num_augmentation_sample=48,
|
| 8 |
+
|
| 9 |
+
crop_size=256,
|
| 10 |
+
atom_crop_size=256 * 8,
|
| 11 |
+
|
| 12 |
+
alpha_confifdence=1e-4,
|
| 13 |
+
alpha_diffusion=4,
|
| 14 |
+
alpha_bond=0,
|
| 15 |
+
alpha_distogram=3e-2,
|
| 16 |
+
alpha_pae=0,
|
| 17 |
+
inf=1e9,
|
| 18 |
+
eps=1e-8,
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Inference Config
|
| 22 |
+
infer_pocket_type="atom", # "ca"
|
| 23 |
+
infer_pocket_cutoff=6, # 8 10 12
|
| 24 |
+
infer_pocket_dist_type="ligand", # "ligand_centre"
|
| 25 |
+
infer_use_pocket=True,
|
| 26 |
+
infer_use_key_res=True,
|
| 27 |
+
|
| 28 |
+
# Training Config
|
| 29 |
+
train_pocket_type_atom_ratio=0.5,
|
| 30 |
+
train_pocket_cutoff_ligand_min=6,
|
| 31 |
+
train_pocket_cutoff_ligand_max=12,
|
| 32 |
+
train_pocket_cutoff_ligand_centre_min=10,
|
| 33 |
+
train_pocket_cutoff_ligand_centre_max=16,
|
| 34 |
+
train_pocket_dist_type_ligand_ratio=0.5,
|
| 35 |
+
train_use_pocket_ratio=0.5,
|
| 36 |
+
train_use_key_res_ratio=0.5,
|
| 37 |
+
|
| 38 |
+
train_shuffle_sym_id=True,
|
| 39 |
+
train_spatial_crop_ligand_ratio=0.2,
|
| 40 |
+
train_spatial_crop_interface_ratio=0.4,
|
| 41 |
+
train_spatial_crop_interface_threshold=15.,
|
| 42 |
+
train_charility_augmentation_ratio=0.1,
|
| 43 |
+
train_use_template_ratio=0.75,
|
| 44 |
+
train_template_mask_max_ratio=0.4,
|
| 45 |
+
|
| 46 |
+
# Other Configs
|
| 47 |
+
max_msa_clusters=128,
|
| 48 |
+
key_res_random_mask_ratio=0.5,
|
| 49 |
+
token_bond_threshold=2.4,
|
| 50 |
+
sigma_data=16.,
|
| 51 |
+
):
|
| 52 |
+
ref_dim = 167
|
| 53 |
+
target_dim = 65
|
| 54 |
+
msa_dim = 34
|
| 55 |
+
|
| 56 |
+
inf = inf
|
| 57 |
+
eps = eps
|
| 58 |
+
|
| 59 |
+
c_m = 256 # 256
|
| 60 |
+
c_s = 512 # 1024
|
| 61 |
+
c_z = 128 # 64 | 128
|
| 62 |
+
c_a = 128 # 128
|
| 63 |
+
c_ap = 16 # 16 | 32
|
| 64 |
+
|
| 65 |
+
if model_name == "toy":
|
| 66 |
+
no_blocks_atom = 2
|
| 67 |
+
no_blocks_evoformer = 2
|
| 68 |
+
no_blocks_pairformer = 2
|
| 69 |
+
no_blocks_dit = 2
|
| 70 |
+
no_blocks_heads = 2
|
| 71 |
+
elif model_name == "tiny":
|
| 72 |
+
no_blocks_atom = 2
|
| 73 |
+
no_blocks_evoformer = 2
|
| 74 |
+
no_blocks_pairformer = 8
|
| 75 |
+
no_blocks_dit = 4
|
| 76 |
+
no_blocks_heads = 2
|
| 77 |
+
elif model_name == "small":
|
| 78 |
+
no_blocks_atom = 2
|
| 79 |
+
no_blocks_evoformer = 3
|
| 80 |
+
no_blocks_pairformer = 16
|
| 81 |
+
no_blocks_dit = 8
|
| 82 |
+
no_blocks_heads = 2
|
| 83 |
+
elif model_name == "medium":
|
| 84 |
+
no_blocks_atom = 3
|
| 85 |
+
no_blocks_evoformer = 4
|
| 86 |
+
no_blocks_pairformer = 24
|
| 87 |
+
no_blocks_dit = 12
|
| 88 |
+
no_blocks_heads = 3
|
| 89 |
+
elif model_name == "full":
|
| 90 |
+
no_blocks_atom = 3
|
| 91 |
+
no_blocks_evoformer = 4
|
| 92 |
+
no_blocks_pairformer = 48
|
| 93 |
+
no_blocks_dit = 24
|
| 94 |
+
no_blocks_heads = 4
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError("Unknown model name")
|
| 97 |
+
|
| 98 |
+
config = {
|
| 99 |
+
"inference_mode": inference_mode,
|
| 100 |
+
"sigma_data": sigma_data,
|
| 101 |
+
"data": {
|
| 102 |
+
"crop_size": crop_size,
|
| 103 |
+
"atom_crop_size": atom_crop_size,
|
| 104 |
+
"max_msa_seqs": 16384,
|
| 105 |
+
"max_uniprot_msa_seqs": 8192,
|
| 106 |
+
"interface_threshold": 15,
|
| 107 |
+
"token_bond_threshold": token_bond_threshold,
|
| 108 |
+
"covalent_bond_threshold": 1.8,
|
| 109 |
+
"max_msa_clusters": max_msa_clusters,
|
| 110 |
+
"resample_msa_in_recycling": True,
|
| 111 |
+
},
|
| 112 |
+
"model": {
|
| 113 |
+
"c_z": c_z,
|
| 114 |
+
"num_augmentation_sample": num_augmentation_sample,
|
| 115 |
+
"diffusion_conditioning": {
|
| 116 |
+
"ref_dim": ref_dim,
|
| 117 |
+
"target_dim": target_dim,
|
| 118 |
+
"msa_dim": msa_dim,
|
| 119 |
+
"c_a": c_a,
|
| 120 |
+
"c_ap": c_ap,
|
| 121 |
+
"c_s": c_s,
|
| 122 |
+
"c_m": c_m,
|
| 123 |
+
"c_z": c_z,
|
| 124 |
+
"inf": inf,
|
| 125 |
+
"eps": eps,
|
| 126 |
+
"no_blocks_atom": no_blocks_atom,
|
| 127 |
+
"no_blocks_evoformer": no_blocks_evoformer,
|
| 128 |
+
"no_blocks_pairformer": no_blocks_pairformer
|
| 129 |
+
},
|
| 130 |
+
"dit": {
|
| 131 |
+
"c_a": c_a,
|
| 132 |
+
"c_ap": c_ap,
|
| 133 |
+
"c_s": c_s,
|
| 134 |
+
"c_z": c_z,
|
| 135 |
+
"inf": inf,
|
| 136 |
+
"eps": eps,
|
| 137 |
+
"no_blocks_atom": no_blocks_atom,
|
| 138 |
+
"no_blocks_dit": no_blocks_dit,
|
| 139 |
+
"sigma_data": sigma_data
|
| 140 |
+
},
|
| 141 |
+
"confidence_module": {
|
| 142 |
+
"c_a": c_a,
|
| 143 |
+
"c_ap": c_ap,
|
| 144 |
+
"c_s": c_s,
|
| 145 |
+
"c_z": c_z,
|
| 146 |
+
"inf": inf,
|
| 147 |
+
"eps": eps,
|
| 148 |
+
"no_blocks_heads": no_blocks_heads,
|
| 149 |
+
"no_blocks_atom": no_blocks_atom,
|
| 150 |
+
}
|
| 151 |
+
},
|
| 152 |
+
"loss": {
|
| 153 |
+
"weighted_mse_loss": {
|
| 154 |
+
"weight": alpha_diffusion,
|
| 155 |
+
"sigma_data": sigma_data,
|
| 156 |
+
"alpha_dna": 5.0,
|
| 157 |
+
"alpha_rna": 5.0,
|
| 158 |
+
"alpha_ligand": 10.0,
|
| 159 |
+
},
|
| 160 |
+
"smooth_lddt_loss": {
|
| 161 |
+
"weight": alpha_diffusion,
|
| 162 |
+
"max_clamp_distance": 15.,
|
| 163 |
+
},
|
| 164 |
+
|
| 165 |
+
"bond_loss": {
|
| 166 |
+
"weight": alpha_diffusion * alpha_bond,
|
| 167 |
+
"sigma_data": sigma_data,
|
| 168 |
+
},
|
| 169 |
+
"key_res_loss": {
|
| 170 |
+
"weight": alpha_diffusion * alpha_bond,
|
| 171 |
+
"sigma_data": sigma_data,
|
| 172 |
+
},
|
| 173 |
+
"distogram_loss": {
|
| 174 |
+
"weight": alpha_distogram,
|
| 175 |
+
"min_bin": 3.25,
|
| 176 |
+
"max_bin": 50.75,
|
| 177 |
+
"no_bins": 39,
|
| 178 |
+
"eps": 1e-9,
|
| 179 |
+
},
|
| 180 |
+
"plddt_loss": {
|
| 181 |
+
"weight": alpha_confifdence,
|
| 182 |
+
"no_bins": 50,
|
| 183 |
+
},
|
| 184 |
+
"pae_loss": {
|
| 185 |
+
"weight": alpha_confifdence * alpha_pae,
|
| 186 |
+
},
|
| 187 |
+
"pde_loss": {
|
| 188 |
+
"weight": alpha_confifdence,
|
| 189 |
+
"min_bin": 0,
|
| 190 |
+
"max_bin": 32,
|
| 191 |
+
"no_bins": 64,
|
| 192 |
+
},
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
return mlc.ConfigDict(config)
|
PhysDock/configs_old.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ml_collections as mlc
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def model_config(
|
| 5 |
+
model_name="full",
|
| 6 |
+
max_recycling_iters=1, # 0
|
| 7 |
+
max_msa_clusters=128, # 32
|
| 8 |
+
crop_size=256, #
|
| 9 |
+
num_augmentation_sample=48, # 128
|
| 10 |
+
alpha_confifdence=1e-4,
|
| 11 |
+
alpha_diffusion=4,
|
| 12 |
+
alpha_bond=0,
|
| 13 |
+
alpha_distogram=3e-2,
|
| 14 |
+
alpha_pae=0,
|
| 15 |
+
use_template=True, # False
|
| 16 |
+
use_mini_rollout=True, # False
|
| 17 |
+
use_flash_attn=False, # False
|
| 18 |
+
custom_rel_token=-1, # 42
|
| 19 |
+
ref_dim=1 + 2 + 2 + 128 + 256, # 167
|
| 20 |
+
mini_rollout_steps=20,
|
| 21 |
+
atom_attention_type="full",
|
| 22 |
+
templ_dim=108,
|
| 23 |
+
interaction_aware=True,
|
| 24 |
+
):
|
| 25 |
+
sigma_data = 16
|
| 26 |
+
# ref_dim = 1 + 2 + 2 + 128 + 256
|
| 27 |
+
msa_dim = 34
|
| 28 |
+
templ_dim = templ_dim
|
| 29 |
+
|
| 30 |
+
inf = 1e9
|
| 31 |
+
eps = 1e-8
|
| 32 |
+
|
| 33 |
+
pair_dropout = 0.25
|
| 34 |
+
msa_dropout = 0.15
|
| 35 |
+
|
| 36 |
+
c_m = 256 # 256
|
| 37 |
+
c_s = 768 # 1024
|
| 38 |
+
c_z = 128 # 64 | 128
|
| 39 |
+
c_tz = 64
|
| 40 |
+
c_a = 128 # 128
|
| 41 |
+
c_ap = 16 # 16 | 32
|
| 42 |
+
|
| 43 |
+
no_blocks_templ = 2
|
| 44 |
+
no_blocks_evo = 48
|
| 45 |
+
no_blocks_atom = 3
|
| 46 |
+
no_blocks_dit = 24
|
| 47 |
+
no_blocks_heads = 4
|
| 48 |
+
if model_name == "small_toy":
|
| 49 |
+
no_blocks_templ = 1
|
| 50 |
+
no_blocks_evo = 1
|
| 51 |
+
no_blocks_atom = 1
|
| 52 |
+
no_blocks_dit = 1
|
| 53 |
+
no_blocks_heads = 1
|
| 54 |
+
elif model_name == "toy":
|
| 55 |
+
no_blocks_templ = 2
|
| 56 |
+
no_blocks_evo = 2
|
| 57 |
+
no_blocks_atom = 2
|
| 58 |
+
no_blocks_dit = 2
|
| 59 |
+
no_blocks_heads = 2
|
| 60 |
+
|
| 61 |
+
elif model_name == "small":
|
| 62 |
+
no_blocks_templ = 2
|
| 63 |
+
no_blocks_evo = 4
|
| 64 |
+
no_blocks_atom = 2
|
| 65 |
+
no_blocks_dit = 2
|
| 66 |
+
no_blocks_heads = 2
|
| 67 |
+
elif model_name == "docking":
|
| 68 |
+
no_blocks_templ = 2
|
| 69 |
+
no_blocks_evo = 8
|
| 70 |
+
no_blocks_atom = 2
|
| 71 |
+
no_blocks_dit = 4
|
| 72 |
+
no_blocks_heads = 2
|
| 73 |
+
elif model_name == "medium":
|
| 74 |
+
no_blocks_templ = 2
|
| 75 |
+
no_blocks_evo = 16
|
| 76 |
+
no_blocks_atom = 3
|
| 77 |
+
no_blocks_dit = 8
|
| 78 |
+
no_blocks_heads = 2
|
| 79 |
+
elif model_name == "large":
|
| 80 |
+
no_blocks_templ = 2
|
| 81 |
+
no_blocks_evo = 24
|
| 82 |
+
no_blocks_atom = 3
|
| 83 |
+
no_blocks_dit = 12
|
| 84 |
+
no_blocks_heads = 4
|
| 85 |
+
elif model_name == "full":
|
| 86 |
+
no_blocks_templ = 2
|
| 87 |
+
no_blocks_evo = 48
|
| 88 |
+
no_blocks_atom = 3
|
| 89 |
+
no_blocks_dit = 24
|
| 90 |
+
no_blocks_heads = 4
|
| 91 |
+
|
| 92 |
+
return mlc.ConfigDict({
|
| 93 |
+
"use_template": use_template,
|
| 94 |
+
"use_mini_rollout": use_mini_rollout,
|
| 95 |
+
"mini_rollout_steps": mini_rollout_steps,
|
| 96 |
+
|
| 97 |
+
"data": {
|
| 98 |
+
"crop_size": crop_size,
|
| 99 |
+
"atom_crop_factor": 10,
|
| 100 |
+
"max_msa_seqs": 16384,
|
| 101 |
+
"max_uniprot_msa_seqs": 8192,
|
| 102 |
+
"interface_threshold": 15,
|
| 103 |
+
"token_bond_threshold": 2.4,
|
| 104 |
+
"covalent_bond_threshold": 1.8,
|
| 105 |
+
"max_msa_clusters": max_msa_clusters,
|
| 106 |
+
"resample_msa_in_recycling": True,
|
| 107 |
+
"max_recycling_iters": max_recycling_iters, # TODO 3
|
| 108 |
+
"sample_msa": {
|
| 109 |
+
"max_msa_clusters": 128,
|
| 110 |
+
"resample_msa_in_recycling": True,
|
| 111 |
+
},
|
| 112 |
+
"make_crop_ids": {
|
| 113 |
+
"crop_size": 384
|
| 114 |
+
}
|
| 115 |
+
},
|
| 116 |
+
"model": {
|
| 117 |
+
"input_feature_embedder": {
|
| 118 |
+
"msa_dim": msa_dim,
|
| 119 |
+
"ref_dim": ref_dim,
|
| 120 |
+
"c_s": c_s,
|
| 121 |
+
"c_m": c_m,
|
| 122 |
+
"c_z": c_z,
|
| 123 |
+
"c_ap": c_ap,
|
| 124 |
+
"c_a": c_a,
|
| 125 |
+
"no_heads": 4,
|
| 126 |
+
"c_hidden": 16,
|
| 127 |
+
"inf": inf,
|
| 128 |
+
"eps": eps,
|
| 129 |
+
"no_blocks": 3,
|
| 130 |
+
"interaction_aware": interaction_aware,
|
| 131 |
+
"custom_rel_token": custom_rel_token,
|
| 132 |
+
},
|
| 133 |
+
"template_pair_embedder": {
|
| 134 |
+
"templ_dim": templ_dim,
|
| 135 |
+
"c_z": c_z,
|
| 136 |
+
"c_tz": c_tz,
|
| 137 |
+
"c_hidden_tz": 16,
|
| 138 |
+
"no_heads_tz": 4,
|
| 139 |
+
"inf": inf,
|
| 140 |
+
"eps": eps,
|
| 141 |
+
"no_blocks": no_blocks_templ,
|
| 142 |
+
},
|
| 143 |
+
"recycling_embedder": {
|
| 144 |
+
"c_m": c_m,
|
| 145 |
+
"c_z": c_z,
|
| 146 |
+
},
|
| 147 |
+
"evoformer_stack": {
|
| 148 |
+
"c_m": c_m,
|
| 149 |
+
"c_z": c_z,
|
| 150 |
+
"c_hidden_m": 32,
|
| 151 |
+
"no_heads_m": 8,
|
| 152 |
+
"c_hidden_z": 32,
|
| 153 |
+
"no_heads_z": 4,
|
| 154 |
+
"c_hidden_opm": 32,
|
| 155 |
+
"inf": inf,
|
| 156 |
+
"eps": eps,
|
| 157 |
+
"no_blocks": no_blocks_evo,
|
| 158 |
+
"single_mode": False,
|
| 159 |
+
},
|
| 160 |
+
"diffusion_module": {
|
| 161 |
+
"ref_dim": ref_dim,
|
| 162 |
+
"c_m": c_m,
|
| 163 |
+
"c_s": c_s,
|
| 164 |
+
"c_z": c_z,
|
| 165 |
+
"c_a": c_a,
|
| 166 |
+
"c_ap": c_ap,
|
| 167 |
+
"no_heads_atom": 4,
|
| 168 |
+
"c_hidden_atom": 16,
|
| 169 |
+
"no_heads": c_ap,
|
| 170 |
+
"c_hidden": 32,
|
| 171 |
+
"inf": inf,
|
| 172 |
+
"eps": eps,
|
| 173 |
+
"no_blocks": no_blocks_dit,
|
| 174 |
+
"no_blocks_atom": no_blocks_atom,
|
| 175 |
+
"num_augmentation_sample": num_augmentation_sample,
|
| 176 |
+
"custom_rel_token": custom_rel_token,
|
| 177 |
+
"use_flash_attn": use_flash_attn,
|
| 178 |
+
"atom_attention_type": atom_attention_type
|
| 179 |
+
},
|
| 180 |
+
"confidence_module": {
|
| 181 |
+
"c_a": c_a,
|
| 182 |
+
"c_ap": c_ap,
|
| 183 |
+
"c_s": c_s,
|
| 184 |
+
"c_m": c_m,
|
| 185 |
+
"c_z": c_z,
|
| 186 |
+
"no_heads_a": 4,
|
| 187 |
+
"c_hidden_a": 16,
|
| 188 |
+
"c_hidden_m": 32,
|
| 189 |
+
"no_heads_m": 8,
|
| 190 |
+
"c_hidden_z": 32,
|
| 191 |
+
"no_heads_z": 4,
|
| 192 |
+
"c_hidden_opm": 32,
|
| 193 |
+
"inf": inf,
|
| 194 |
+
"eps": eps,
|
| 195 |
+
"no_blocks": no_blocks_heads,
|
| 196 |
+
"no_blocks_atom": no_blocks_atom,
|
| 197 |
+
"c_pae": 64,
|
| 198 |
+
"c_pde": 64,
|
| 199 |
+
"c_plddt": 50,
|
| 200 |
+
"c_distogram": 39,
|
| 201 |
+
},
|
| 202 |
+
"loss": {
|
| 203 |
+
"weighted_mse_loss": {
|
| 204 |
+
"weight": alpha_diffusion,
|
| 205 |
+
"sigma_data": sigma_data,
|
| 206 |
+
"alpha_dna": 5.0,
|
| 207 |
+
"alpha_rna": 5.0,
|
| 208 |
+
"alpha_ligand": 10.0,
|
| 209 |
+
},
|
| 210 |
+
"smooth_lddt_loss": {
|
| 211 |
+
"weight": alpha_diffusion,
|
| 212 |
+
"max_clamp_distance": 15.,
|
| 213 |
+
},
|
| 214 |
+
"clamp_distance_loss": {
|
| 215 |
+
"weight": alpha_diffusion * 0.2,
|
| 216 |
+
"max_clamp_distance": 10,
|
| 217 |
+
},
|
| 218 |
+
|
| 219 |
+
"bond_loss": {
|
| 220 |
+
"weight": alpha_diffusion * alpha_bond,
|
| 221 |
+
"sigma_data": sigma_data,
|
| 222 |
+
},
|
| 223 |
+
"distogram_loss": {
|
| 224 |
+
"weight": alpha_distogram,
|
| 225 |
+
"min_bin": 3.25,
|
| 226 |
+
"max_bin": 50.75,
|
| 227 |
+
"no_bins": 39,
|
| 228 |
+
"eps": 1e-9,
|
| 229 |
+
},
|
| 230 |
+
"plddt_loss": {
|
| 231 |
+
"weight": alpha_confifdence,
|
| 232 |
+
"no_bins": 50,
|
| 233 |
+
},
|
| 234 |
+
"pae_loss": {
|
| 235 |
+
"weight": alpha_confifdence * alpha_pae,
|
| 236 |
+
},
|
| 237 |
+
"pde_loss": {
|
| 238 |
+
"weight": alpha_confifdence,
|
| 239 |
+
"min_bin": 0,
|
| 240 |
+
"max_bin": 32,
|
| 241 |
+
"no_bins": 64,
|
| 242 |
+
},
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
})
|
PhysDock/data/__init__.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Union, Any

import numpy as np
import torch
# FIX: ``scipy.sparse.coo`` is a private module that was deprecated in SciPy 1.8
# and removed in later releases; ``coo_matrix`` must come from the public
# ``scipy.sparse`` namespace.
from scipy.sparse import coo_matrix

# TODO: Keep only ref_mask == 1 for all atomwise features

'''
Notation:
    batch,
    token dimension: i, j, k
    flat atom dimension: l, m | WARNING: We should flatten due to local atom attention mask
    sequence dimension: s(msa) t(time)
    head dimension: h

    ####################
    z_ij: pair repr
    {z_ij}: all pair repr
    x: atom position
    {x_l}: flat atom list, full atomic structure
    exist a mapping: flat atom index -> token index and within token atom index: l -> i, a

    ####################
    a: atom representation
        have the same shape as s
        exist a transform: flat atom representation -> atom represenation
    s: token representation
    z: pair representation
'''

# A raw feature dict as produced by the data pipeline: numpy arrays, sparse COO
# matrices for pairwise features, ``None`` for absent entries.
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
# The same mapping after conversion to torch tensors for model consumption.
TensorDict = Dict[str, Union[torch.Tensor, Any]]

# Single-character chain identifiers in PDB convention order
# (uppercase, lowercase, digits) — 62 chains maximum.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

# Sentinel strings standing in for symbolic dimension sizes in SHAPE_SCHIME.
# They compare by identity/equality only; actual sizes are resolved at runtime.
NUM_CONFORMER = "num conformer placeholder"
NUM_TOKEN = "num tokens placeholder"
NUM_ATOM = "num atoms placeholder"

NUM_SEQ = "num MSAs placeholder"
NUM_TEMPL = "num templates placeholder"

NUM_RECYCLING = "num recycling placeholder"
NUM_SAMPLE = "num sample placeholder"

# Symbolic shape schema for every feature key. ``None`` means the dimension is
# unconstrained. (Name "SCHIME" [sic] is kept for backward compatibility.)
SHAPE_SCHIME = {
    ################################################################
    # Conformerwise Feature

    # Tokenwise Feature
    "residue_index": [NUM_TOKEN],
    "restype": [NUM_TOKEN],
    "token_index": [NUM_TOKEN],
    "s_mask": [NUM_TOKEN],
    "is_protein": [NUM_TOKEN],
    "is_rna": [NUM_TOKEN],
    "is_dna": [NUM_TOKEN],
    "is_ligand": [NUM_TOKEN],
    "token_id_to_centre_atom_id": [NUM_TOKEN],
    "token_id_to_pseudo_beta_atom_id": [NUM_TOKEN],
    "token_id_to_chunk_sizes": [NUM_TOKEN],
    "token_id_to_conformer_id": [NUM_TOKEN],
    "asym_id": [NUM_TOKEN],
    "entity_id": [NUM_TOKEN],
    "sym_id": [NUM_TOKEN],
    "token_bonds": [NUM_TOKEN, NUM_TOKEN],
    "target_feat": [NUM_TOKEN],
    "token_exists": [NUM_TOKEN],
    "spatial_crop_target_res_mask": [NUM_TOKEN],

    # Atomwise features
    "ref_space_uid": [NUM_ATOM],
    "atom_index": [NUM_ATOM],
    "ref_feat": [NUM_ATOM, 389],
    "ref_pos": [NUM_ATOM, 3],
    "a_mask": [NUM_ATOM],
    "atom_id_to_token_id": [NUM_ATOM],
    "x_gt": [NUM_ATOM, 3],
    "x_exists": [NUM_ATOM],
    "rec_mask": [NUM_ATOM, NUM_ATOM],

    "msa": [NUM_SEQ, NUM_TOKEN],
    "deletion_matrix": [NUM_SEQ, NUM_TOKEN],
    "msa_feat": [NUM_SEQ, NUM_TOKEN, None],
    "crop_idx": [None],
    "crop_idx_atom": [None],

    #
    "x_centre": [None],
    # # Template features
    # "template_restype": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_backbone_frame_mask": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_distogram": [NUM_TEMPLATES, NUM_TOKEN, NUM_TOKEN, 39],
    # "template_unit_vector": [NUM_TEMPLATES, NUM_TOKEN, NUM_TOKEN, 3],

    ###########################################################
}

# Feature keys only needed for supervision (losses); empty for now.
SUPERVISED_FEATURES = [

]

# Feature keys consumed by the model forward pass; empty for now.
UNSUPERVISED_FEATURES = [

]
|
PhysDock/data/alignment_runner.py
ADDED
|
@@ -0,0 +1,937 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os.path
|
| 3 |
+
import shutil
|
| 4 |
+
from functools import partial
|
| 5 |
+
import tqdm
|
| 6 |
+
from typing import Optional, Mapping, Any, Union
|
| 7 |
+
|
| 8 |
+
from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
|
| 9 |
+
from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
|
| 10 |
+
from PhysDock.data.tools.parsers import parse_fasta
|
| 11 |
+
|
| 12 |
+
TemplateSearcher = Union[hhsearch.HHSearch]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AlignmentRunner:
    """Runs genetic (MSA) and template searches for a single query FASTA.

    Each ``*_runner`` attribute is either ``None`` — when the corresponding
    tool binary or database path was not supplied or does not exist — or a
    callable ``runner(fasta_path, out_path)`` that executes the search and
    writes the resulting alignment file.

    Note: ``kalign_binary_path`` and ``hhsearch_binary_path`` are accepted for
    interface compatibility but are not used directly here (template search is
    driven through ``template_searcher`` / ``template_featurizer``).
    """

    def __init__(
            self,
            # Homology search tools
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            nhmmer_binary_path: Optional[str] = None,
            hmmbuild_binary_path: Optional[str] = None,
            hmmalign_binary_path: Optional[str] = None,
            kalign_binary_path: Optional[str] = None,

            # Template search tools
            hhsearch_binary_path: Optional[str] = None,
            template_searcher: Optional[TemplateSearcher] = None,
            template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,

            # Databases
            uniref90_database_path: Optional[str] = None,
            uniprot_database_path: Optional[str] = None,
            uniclust30_database_path: Optional[str] = None,
            uniref30_database_path: Optional[str] = None,
            bfd_database_path: Optional[str] = None,
            reduced_bfd_database_path: Optional[str] = None,
            mgnify_database_path: Optional[str] = None,
            rfam_database_path: Optional[str] = None,
            rnacentral_database_path: Optional[str] = None,
            nt_database_path: Optional[str] = None,
            #
            no_cpus: int = 8,
            # Per-database sequence limits and maximum hit counts
            uniref90_seq_limit: int = 100000,
            uniprot_seq_limit: int = 500000,
            reduced_bfd_seq_limit: int = 50000,
            mgnify_seq_limit: int = 50000,
            uniref90_max_hits: int = 10000,
            uniprot_max_hits: int = 50000,
            reduced_bfd_max_hits: int = 5000,
            mgnify_max_hits: int = 5000,
            rfam_max_hits: int = 10000,
            rnacentral_max_hits: int = 10000,
            nt_max_hits: int = 10000,
    ):
        # All runners default to None; each is populated below only when both
        # its binary and its database(s) are available on disk.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.reduced_bfd_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None
        self.rfam_nhmmer_runner = None
        self.rnacentral_nhmmer_runner = None
        self.nt_nhmmer_runner = None
        self.rna_realign_runner = None
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        def _all_exists(*objs, hhblits_mode=False):
            """Return True iff every path is non-None and exists.

            In ``hhblits_mode`` only the parent directory is checked, because
            HHblits database paths are prefixes (e.g. ``.../uniclust30_2018_08``)
            rather than real files.
            """
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
                fasta_path: str,
                msa_out_path: str,
                msa_runner,
                msa_format: str,
                max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool and writes its output to ``msa_out_path``.

            ``max_sto_sequences`` is only honored for Stockholm output.
            The output path's extension must match ``msa_format``.
            """
            if (msa_format == "sto" and max_sto_sequences is not None):
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        def _run_rna_realign_tool(
                fasta_path: str,
                msa_in_path: str,
                msa_out_path: str,
                use_precompute=True,
        ):
            """Realign an RNA Stockholm MSA against the query with hmmalign.

            An empty input MSA yields an empty output file. With
            ``use_precompute`` a non-empty existing output is kept; an empty
            output paired with a non-empty input is treated as a failed prior
            run and redone (with a warning).
            """
            runner = hmmalign.Hmmalign(
                hmmbuild_binary_path=hmmbuild_binary_path,
                hmmalign_binary_path=hmmalign_binary_path,
            )
            if os.path.exists(msa_in_path) and os.path.getsize(msa_in_path) == 0:
                # Nothing to realign: create/truncate an empty output file.
                with open(msa_out_path, "w") as f:
                    pass
                return
            if use_precompute:
                if os.path.exists(msa_in_path) and os.path.exists(msa_out_path):
                    if os.path.getsize(msa_in_path) > 0 and os.path.getsize(msa_out_path) == 0:
                        logging.warning(f"The msa realign file size is zero but the origin file size is over 0! "
                                        f"fasta: {fasta_path} msa_in_file: {msa_in_path}")
                        runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
                else:
                    runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
            else:
                runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)

        # uniclust30 and uniref30 are alternative HHblits companions; at most
        # one may be configured.
        assert uniclust30_database_path is None or uniref30_database_path is None, "Only one used"

        # Jackhmmer-based protein searches
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )
        if _all_exists(jackhmmer_binary_path, reduced_bfd_database_path):
            self.reduced_bfd_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=reduced_bfd_database_path,
                    seq_limit=reduced_bfd_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=reduced_bfd_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits against BFD + (uniref30 XOR uniclust30)
        if _all_exists(hhblits_binary_path, bfd_database_path, uniref30_database_path, hhblits_mode=True):
            self.bfd_uniref30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniref30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )
        elif _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

        # Nhmmer-based RNA searches
        if _all_exists(nhmmer_binary_path, rfam_database_path):
            self.rfam_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rfam_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rfam_max_hits
            )
        if _all_exists(nhmmer_binary_path, rnacentral_database_path):
            self.rnacentral_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rnacentral_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rnacentral_max_hits
            )
        if _all_exists(nhmmer_binary_path, nt_database_path):
            self.nt_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=nt_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=nt_max_hits
            )

        # RNA realignment needs both hmmbuild and hmmalign.
        if _all_exists(hmmbuild_binary_path, hmmalign_binary_path):
            self.rna_realign_runner = _run_rna_realign_tool

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Run every configured search for one FASTA, writing into ``output_msas_dir``.

        With ``use_precompute`` existing output files are kept and their
        searches skipped. A search is also skipped when its final feature
        pickle (``{md5}.pkl.gz`` under the sibling ``*_features`` directories)
        already exists.

        NOTE(review): the md5 key assumes a single-chain protein query
        (``prefix = "protein"``, ``seqs[0]``) and that ``output_msas_dir`` is
        two levels below the feature root — confirm against callers.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        reduced_bfd_out_path = os.path.join(output_msas_dir, "reduced_bfd_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniref30_out_path = os.path.join(output_msas_dir, f"bfd_uniref30_hits.a3m")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, f"bfd_uniclust30_hits.a3m")
        rfam_out_path = os.path.join(output_msas_dir, f"rfam_hits.sto")
        rfam_out_realigned_path = os.path.join(output_msas_dir, f"rfam_hits_realigned.sto")
        rnacentral_out_path = os.path.join(output_msas_dir, f"rnacentral_hits.sto")
        rnacentral_out_realigned_path = os.path.join(output_msas_dir, f"rnacentral_hits_realigned.sto")
        nt_out_path = os.path.join(output_msas_dir, f"nt_hits.sto")
        nt_out_realigned_path = os.path.join(output_msas_dir, f"nt_hits_realigned.sto")

        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        # Feature root is two directory levels above the msas dir.
        output_feature = os.path.dirname(output_msas_dir)
        output_feature = os.path.dirname(output_feature)
        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")
        pkl_save_path_temp = os.path.join(output_feature, "template_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_temp):
            if not os.path.exists(uniref90_out_path) or not use_precompute or not os.path.exists(templates_out_path):
                if not os.path.exists(uniref90_out_path):
                    print(uniref90_out_path)
                    self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

                print("begin templates")
                if templates_out_path is not None:
                    # Template search is best-effort: any failure is logged and
                    # the remaining searches still run.
                    try:
                        os.makedirs(templates_out_path, exist_ok=True)
                        seq, dec = parsers.parse_fasta(load_txt(input_fasta_path))
                        input_sequence = seq[0]
                        # Prune the uniref90 MSA before handing it to the
                        # template searcher.
                        msa_for_templates = parsers.truncate_stockholm_msa(
                            uniref90_out_path, max_sequences=10000
                        )
                        msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                            msa_for_templates
                        )
                        if self.template_searcher.input_format == "sto":
                            pdb_templates_result = self.template_searcher.query(msa_for_templates)
                        elif self.template_searcher.input_format == "a3m":
                            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                            pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                        else:
                            raise ValueError(
                                "Unrecognized template input format: "
                                f"{self.template_searcher.input_format}"
                            )

                        pdb_hits_out_path = os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                        )
                        with open(os.path.join(
                                templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                        ), "w") as f:
                            f.write(pdb_templates_result)

                        pdb_template_hits = self.template_searcher.get_template_hits(
                            output_string=pdb_templates_result, input_sequence=input_sequence
                        )
                        templates_result = self.template_featurizer.get_templates(
                            query_sequence=input_sequence, hits=pdb_template_hits
                        )
                        # FIX: the dump must happen inside the try block.
                        # Previously it ran after the except handler, so a
                        # failed search raised NameError on the undefined
                        # ``templates_result`` / ``pdb_hits_out_path``.
                        dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                    except Exception:
                        logging.exception("An error in template searching")
        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.reduced_bfd_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(reduced_bfd_out_path) or not use_precompute:
                self.reduced_bfd_jackhmmer_runner(input_fasta_path, reduced_bfd_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniref30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniref30_out_path) or not use_precompute:
                self.bfd_uniref30_hhblits_runner(input_fasta_path, bfd_uniref30_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
        # NOTE: RNA searches are intentionally disabled in this version.
        # if self.rfam_nhmmer_runner is not None:
        #     if not os.path.exists(rfam_out_path) or not use_precompute:
        #         self.rfam_nhmmer_runner(input_fasta_path, rfam_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(rfam_out_path):
        #         self.rna_realign_runner(input_fasta_path, rfam_out_path, rfam_out_realigned_path)
        # if self.rnacentral_nhmmer_runner is not None:
        #     if not os.path.exists(rnacentral_out_path) or not use_precompute:
        #         self.rnacentral_nhmmer_runner(input_fasta_path, rnacentral_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(rnacentral_out_path):
        #         self.rna_realign_runner(input_fasta_path, rnacentral_out_path, rnacentral_out_realigned_path)
        # if self.nt_nhmmer_runner is not None:
        #     if not os.path.exists(nt_out_path) or not use_precompute:
        #         self.nt_nhmmer_runner(input_fasta_path, nt_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(nt_out_path):
        #         self.rna_realign_runner(input_fasta_path, nt_out_path, nt_out_realigned_path)
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class DataProcessor:
|
| 361 |
+
def __init__(
|
| 362 |
+
self,
|
| 363 |
+
alphafold3_database_path,
|
| 364 |
+
jackhmmer_binary_path: Optional[str] = None,
|
| 365 |
+
hhblits_binary_path: Optional[str] = None,
|
| 366 |
+
nhmmer_binary_path: Optional[str] = None,
|
| 367 |
+
kalign_binary_path: Optional[str] = None,
|
| 368 |
+
hmmbuild_binary_path: Optional[str] = None,
|
| 369 |
+
hmmalign_binary_path: Optional[str] = None,
|
| 370 |
+
hhsearch_binary_path: Optional[str] = None,
|
| 371 |
+
template_searcher: Optional[TemplateSearcher] = None,
|
| 372 |
+
template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,
|
| 373 |
+
n_cpus: int = 8,
|
| 374 |
+
n_workers: int = 1,
|
| 375 |
+
):
|
| 376 |
+
'''
|
| 377 |
+
Database Versions:
|
| 378 |
+
Training:
|
| 379 |
+
uniref90: v2022_05
|
| 380 |
+
bfd:
|
| 381 |
+
reduces_bfd:
|
| 382 |
+
uniclust30: v2018_08
|
| 383 |
+
uniprot: v2020_05
|
| 384 |
+
mgnify: v2022_05
|
| 385 |
+
rfam: v14.9
|
| 386 |
+
rnacentral: v21.0
|
| 387 |
+
nt: v2023_02_23
|
| 388 |
+
Inference:
|
| 389 |
+
uniref90: v2022_05
|
| 390 |
+
bfd:
|
| 391 |
+
reduces_bfd:
|
| 392 |
+
uniclust30: v2018_08
|
| 393 |
+
uniprot: v2021_04 *
|
| 394 |
+
mgnify: v2022_05
|
| 395 |
+
rfam: v14.9
|
| 396 |
+
rnacentral: v21.0
|
| 397 |
+
nt: v2023_02_23
|
| 398 |
+
Inference Ligand:
|
| 399 |
+
uniref90: v2020_01 *
|
| 400 |
+
bfd:
|
| 401 |
+
reduces_bfd:
|
| 402 |
+
uniclust30: v2018_08
|
| 403 |
+
uniprot: v2020_05
|
| 404 |
+
mgnify: v2018_12 *
|
| 405 |
+
rfam: v14.9
|
| 406 |
+
rnacentral: v21.0
|
| 407 |
+
nt: v2023_02_23
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
alphafold3_database_path: Database dir that contains all alphafold3 databases
|
| 411 |
+
jackhmmer_binary_path:
|
| 412 |
+
hhblits_binary_path:
|
| 413 |
+
nhmmer_binary_path:
|
| 414 |
+
kalign_binary_path:
|
| 415 |
+
hmmaligh_binary_path:
|
| 416 |
+
n_cpus:
|
| 417 |
+
n_workers:
|
| 418 |
+
'''
|
| 419 |
+
self.jackhmmer_binary_path = jackhmmer_binary_path
|
| 420 |
+
self.hhblits_binary_path = hhblits_binary_path
|
| 421 |
+
self.nhmmer_binary_path = nhmmer_binary_path
|
| 422 |
+
self.hmmbuild_binary_path = hmmbuild_binary_path
|
| 423 |
+
self.hmmalign_binary_path = hmmalign_binary_path
|
| 424 |
+
self.hhsearch_binary_path = hhsearch_binary_path
|
| 425 |
+
|
| 426 |
+
self.template_searcher = template_searcher
|
| 427 |
+
self.template_featurizer = template_featurizer
|
| 428 |
+
|
| 429 |
+
self.n_cpus = n_cpus
|
| 430 |
+
self.n_workers = n_workers
|
| 431 |
+
|
| 432 |
+
self.uniref90_database_path = os.path.join(
|
| 433 |
+
alphafold3_database_path, "uniref90", "uniref90.fasta"
|
| 434 |
+
)
|
| 435 |
+
################### TODO: DEBUG
|
| 436 |
+
self.uniprot_database_path = os.path.join(
|
| 437 |
+
alphafold3_database_path, "uniprot", "uniprot.fasta"
|
| 438 |
+
)
|
| 439 |
+
self.bfd_database_path = os.path.join(
|
| 440 |
+
alphafold3_database_path, "bfd", "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
|
| 441 |
+
)
|
| 442 |
+
self.uniclust30_database_path = os.path.join(
|
| 443 |
+
alphafold3_database_path, "uniclust30", "uniclust30_2018_08", "uniclust30_2018_08"
|
| 444 |
+
)
|
| 445 |
+
################### TODO: check alphafold2 multimer uniref30 version
|
| 446 |
+
self.uniref_30_database_path = os.path.join(
|
| 447 |
+
alphafold3_database_path, "uniref30", "v2020_06"
|
| 448 |
+
)
|
| 449 |
+
# self.reduced_bfd_database_path = os.path.join(
|
| 450 |
+
# alphafold3_database_path,"reduced_bfd"
|
| 451 |
+
# )
|
| 452 |
+
|
| 453 |
+
self.mgnify_database_path = os.path.join(
|
| 454 |
+
alphafold3_database_path, "mgnify", "mgnify", "mgy_clusters.fa"
|
| 455 |
+
)
|
| 456 |
+
self.rfam_database_path = os.path.join(
|
| 457 |
+
alphafold3_database_path, "rfam", "v14.9", "Rfam_af3_clustered_rep_seq.fasta"
|
| 458 |
+
)
|
| 459 |
+
self.rnacentral_database_path = os.path.join(
|
| 460 |
+
alphafold3_database_path, "rnacentral", "v21.0", "rnacentral_db_rep_seq.fasta"
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
self.nt_database_path = os.path.join(
|
| 464 |
+
# alphafold3_database_path, "nt", "v2023_02_23", "nt_af3_clustered_rep_seq.fasta" # DEBUG
|
| 465 |
+
alphafold3_database_path, "nt", "v2023_02_23", "nt.fasta"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
self.runner_args_map = {
|
| 469 |
+
"uniref90": {
|
| 470 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 471 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 472 |
+
},
|
| 473 |
+
"bfd_uniclust30": {
|
| 474 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 475 |
+
"bfd_database_path": self.bfd_database_path,
|
| 476 |
+
"uniclust30_database_path": self.uniclust30_database_path
|
| 477 |
+
},
|
| 478 |
+
"bfd_uniref30": {
|
| 479 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 480 |
+
"bfd_database_path": self.bfd_database_path,
|
| 481 |
+
"uniref_30_database_path": self.uniref_30_database_path
|
| 482 |
+
},
|
| 483 |
+
|
| 484 |
+
"mgnify": {
|
| 485 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 486 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 487 |
+
},
|
| 488 |
+
"uniprot": {
|
| 489 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 490 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 491 |
+
},
|
| 492 |
+
###################### RNA ########################
|
| 493 |
+
"rfam": {
|
| 494 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 495 |
+
"rfam_database_path": self.rfam_database_path,
|
| 496 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 497 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 498 |
+
},
|
| 499 |
+
"rnacentral": {
|
| 500 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 501 |
+
"rnacentral_database_path": self.rnacentral_database_path,
|
| 502 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 503 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 504 |
+
},
|
| 505 |
+
"nt": {
|
| 506 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 507 |
+
"nt_database_path": self.nt_database_path,
|
| 508 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 509 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 510 |
+
},
|
| 511 |
+
|
| 512 |
+
###################################################
|
| 513 |
+
"alphafold2": {
|
| 514 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 515 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 516 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 517 |
+
"bfd_database_path": self.bfd_database_path,
|
| 518 |
+
"uniclust30_database_path": self.uniclust30_database_path,
|
| 519 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 520 |
+
},
|
| 521 |
+
"alphafold2_multimer": {
|
| 522 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 523 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 524 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 525 |
+
"bfd_database_path": self.bfd_database_path,
|
| 526 |
+
"uniref_30_database_path": self.uniref_30_database_path,
|
| 527 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 528 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 529 |
+
},
|
| 530 |
+
"alphafold3": {
|
| 531 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 532 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 533 |
+
"template_searcher": self.template_searcher,
|
| 534 |
+
"template_featurizer": self.template_featurizer,
|
| 535 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 536 |
+
"bfd_database_path": self.bfd_database_path,
|
| 537 |
+
"uniclust30_database_path": self.uniclust30_database_path,
|
| 538 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 539 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 540 |
+
},
|
| 541 |
+
|
| 542 |
+
"rna": {
|
| 543 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 544 |
+
"rfam_database_path": self.rfam_database_path,
|
| 545 |
+
"rnacentral_database_path": self.rnacentral_database_path,
|
| 546 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 547 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 548 |
+
},
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
|
| 552 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 553 |
+
if isinstance(input_fasta_path, list):
|
| 554 |
+
input_fasta_paths = input_fasta_path
|
| 555 |
+
elif os.path.isdir(input_fasta_path):
|
| 556 |
+
input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
|
| 557 |
+
elif os.path.isfile(input_fasta_path):
|
| 558 |
+
input_fasta_paths = [input_fasta_path]
|
| 559 |
+
else:
|
| 560 |
+
input_fasta_paths = []
|
| 561 |
+
Exception("Can't parse input fasta path!")
|
| 562 |
+
seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
|
| 563 |
+
# sequences = [parsers.parse_fasta(load_txt(path))[0][0] for path in input_fasta_paths]
|
| 564 |
+
# TODO: debug
|
| 565 |
+
if convert_md5:
|
| 566 |
+
output_msas_dirs = [os.path.join(output_dir, convert_md5_string(f"{prefix}:{i}")) for i in
|
| 567 |
+
seqs]
|
| 568 |
+
else:
|
| 569 |
+
output_msas_dirs = [os.path.join(output_dir, os.path.split(i)[1].split(".")[0]) for i in input_fasta_paths]
|
| 570 |
+
io_tuples = [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]
|
| 571 |
+
return io_tuples
|
| 572 |
+
|
| 573 |
+
def _process_iotuple(self, io_tuple, msas_type, use_precompute=True):
    """Run one alignment task; failures are logged, never raised.

    Args:
        io_tuple: (input_fasta_path, output_msas_dir) pair.
        msas_type: key into ``self.runner_args_map`` selecting the
            database/binary configuration for the AlignmentRunner.
        use_precompute: reuse existing alignment outputs when present.
    """
    fasta_path, out_dir = io_tuple
    alignment_runner = AlignmentRunner(
        **self.runner_args_map[msas_type],
        no_cpus=self.n_cpus
    )
    try:
        alignment_runner.run(fasta_path, out_dir, use_precompute=use_precompute)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; log with traceback for diagnosability.
        logging.exception(f"{fasta_path}:{out_dir} task failed!")
|
| 583 |
+
|
| 584 |
+
def process(self, input_fasta_path, output_dir, msas_type="rfam", convert_md5=True, use_precompute=True):
    """Align every input fasta against the database set selected by ``msas_type``."""
    # RNA database sets get an "rna" md5 prefix; everything else is protein.
    rna_types = {"rfam", "rnacentral", "nt", "rna"}
    seq_prefix = "rna" if msas_type in rna_types else "protein"
    work_pairs = self._parse_io_tuples(
        input_fasta_path, output_dir, convert_md5=convert_md5, prefix=seq_prefix
    )
    worker_fn = partial(self._process_iotuple, msas_type=msas_type, use_precompute=use_precompute)
    run_pool_tasks(worker_fn, work_pairs, num_workers=self.n_workers, return_dict=False)
|
| 590 |
+
|
| 591 |
+
def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
    """Copy alignment dirs named by fasta stem to dirs named by sequence md5.

    Args:
        input_fasta_path: same accepted forms as ``_parse_io_tuples``.
        output_dir: existing alignment root with stem-named sub-dirs.
        md5_output_dir: destination root for md5-named copies.
        prefix: sequence-type tag mixed into the md5.
    """
    import subprocess  # local import: only needed for the copy step

    io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
    io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)

    for (_, src), (_, dst) in tqdm.tqdm(zip(io_tuples, io_tuples_md5), total=len(io_tuples)):
        # Argument-list form instead of a shell-interpolated `os.system` string:
        # paths containing spaces or shell metacharacters can no longer break
        # (or inject into) the command.
        subprocess.run(["cp", "-r", os.path.abspath(src), os.path.abspath(dst)], check=False)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def run_homo_search(
        out_dir,
        save_dir,
        feature_dir,
        pdb_70_dir,
        template_mmcif_dir,
        max_template_date="2021-09-30",
        obsolete_pdbs_path=None,
        use_precompute=True,
):
    """Run the full homology-search pipeline over the fastas in ``save_dir``.

    Produces, under ``out_dir/features``: raw MSAs (``msas``), packed MSA
    features (``msa_features``), uniprot pairing features
    (``uniprot_msa_features``) and template features (``template_features``).

    Args:
        out_dir: output root; feature sub-dirs are created beneath it.
        save_dir: directory containing the per-chain input fasta files.
        feature_dir: AlphaFold3-style database root passed to DataProcessor.
        pdb_70_dir: pdb70 database prefix for HHSearch template search.
        template_mmcif_dir: directory of template mmCIF files.
        max_template_date: newest template release date to accept.
        obsolete_pdbs_path: optional obsolete-PDB mapping file.
        use_precompute: reuse alignment outputs already on disk.
    """
    data_processor = DataProcessor(
        alphafold3_database_path=feature_dir,
        jackhmmer_binary_path="/usr/bin/jackhmmer",
        hhblits_binary_path="/usr/bin/hhblits",
        hhsearch_binary_path="/usr/bin/hhsearch",
        template_searcher=hhsearch.HHSearch(
            binary_path="/usr/bin/hhsearch",
            databases=[pdb_70_dir]
        ),
        template_featurizer=templates.HhsearchHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=20,
            kalign_binary_path="/usr/bin/kalign",
            release_dates_path=None,
            obsolete_pdbs_path=obsolete_pdbs_path,
        ),
        n_cpus=32,
        n_workers=12
    )

    output_dir = os.path.join(out_dir, "features/msas")
    os.makedirs(output_dir, exist_ok=True)
    files = os.listdir(save_dir)
    files = [os.path.join(save_dir, file) for file in files]

    try:
        data_processor.process(
            input_fasta_path=files,
            output_dir=output_dir,
            msas_type="alphafold3",
            convert_md5=True,
            use_precompute=use_precompute
        )
        print(f"save msa to {output_dir}")
    except Exception as e:
        # Best-effort stage: report and continue with whatever was produced.
        print(e)

    msa_dir = os.path.join(out_dir, "features/msa_features")
    os.makedirs(msa_dir, exist_ok=True)
    from PhysDock.data.tools.dataset_manager import DatasetManager
    from PhysDock.data.tools.convert_unifold_template_to_stfold import \
        convert_unifold_template_feature_to_stfold_unifold_feature

    try:
        out = DatasetManager.convert_msas_out_to_msa_features(
            input_fasta_path=save_dir,
            output_dir=output_dir,
            msa_feature_dir=msa_dir,
            convert_md5=True,
            num_workers=2
        )
        print(f"save msa feature to {msa_dir}")
    except Exception as e:
        # Narrowed from a bare `except:`; keep best-effort semantics but surface the cause.
        print(e)

    try:
        msa_dir_uni = os.path.join(out_dir, "features/uniprot_msa_features")
        os.makedirs(msa_dir_uni, exist_ok=True)
        out = DatasetManager.convert_msas_out_to_uniprot_msa_features(
            input_fasta_path=save_dir,
            output_dir=output_dir,
            uniprot_msa_feature_dir=msa_dir_uni,
            convert_md5=True,
            num_workers=2
        )
        print(f"save uni msa feature to {msa_dir_uni}")
    except Exception as e:
        print(e)

    templ_dir_uni = os.path.join(out_dir, "features/template_features")
    os.makedirs(templ_dir_uni, exist_ok=True)
    try:
        files = os.listdir(save_dir)
        # BUGFIX: these entries come from save_dir, so they must be joined with
        # save_dir (the original joined with out_dir, yielding nonexistent paths).
        files = [os.path.join(save_dir, file) for file in files[::-1]]
        run_pool_tasks(convert_unifold_template_feature_to_stfold_unifold_feature, files, num_workers=16)
    except Exception as e:
        print(e)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class STDockAlignmentRunner:
    """Configures and runs the protein MSA / template search stack.

    Each runner (uniref90/uniprot/mgnify jackhmmer, bfd+uniclust30 hhblits) is
    created only when both its binary and database paths exist; otherwise it
    stays ``None`` and is skipped at run time.
    """

    def __init__(
            self,
            # Homo Search Tools Path
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            kalign_binary_path: Optional[str] = None,
            hhsearch_binary_path: Optional[str] = None,

            # Databases
            uniref90_database_path: Optional[str] = None,
            uniprot_database_path: Optional[str] = None,
            uniclust30_database_path: Optional[str] = None,
            bfd_database_path: Optional[str] = None,
            mgnify_database_path: Optional[str] = None,
            pdb_70_database_path: Optional[str] = None,
            mmcif_files_path: Optional[str] = None,
            obsolete_pdbs_path: Optional[str] = None,

            # Settings
            max_template_date: str = "2021-09-30",
            max_template_hits: int = 20,
            no_cpus: int = 8,
            # Limitations
            uniref90_seq_limit: int = 100000,
            uniprot_seq_limit: int = 500000,
            mgnify_seq_limit: int = 50000,
            uniref90_max_hits: int = 10000,
            uniprot_max_hits: int = 50000,
            mgnify_max_hits: int = 5000,
    ):
        super().__init__()
        self.jackhmmer_binary_path = jackhmmer_binary_path
        self.hhblits_binary_path = hhblits_binary_path

        self.uniref90_database_path = uniref90_database_path
        self.mgnify_database_path = mgnify_database_path
        self.uniclust30_database_path = uniclust30_database_path
        self.bfd_database_path = bfd_database_path
        self.uniprot_database_path = uniprot_database_path

        # BUGFIX: default every runner to None. Previously these attributes
        # were only created inside the `_all_exists(...)` branches below, so
        # `run_protein_msas` raised AttributeError whenever a binary/database
        # was missing (the v2 AlignmentRunner already follows this pattern).
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        self.bfd_uniclust30_hhblits_runner = None

        self.template_searcher = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path,
            databases=[pdb_70_database_path]
        )
        self.template_featurizer = templates.HhsearchHitFeaturizer(
            mmcif_dir=mmcif_files_path,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            obsolete_pdbs_path=obsolete_pdbs_path,
        )

        def _all_exists(*objs, hhblits_mode=False):
            # All paths must be non-None and exist. In hhblits mode, database
            # paths are prefixes (e.g. ".../uniclust30_2018_08"), so only the
            # containing directory is required to exist.
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
                fasta_path: str,
                msa_out_path: str,
                msa_runner,
                msa_format: str,
                max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool, checking if output already exists first."""
            if (msa_format == "sto" and max_sto_sequences is not None):
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # Output filename extension must match the declared format.
            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits
        if _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

    def run_protein_msas(
            self,
            input_fasta_path,
            output_msas_dir,
            use_precompute=True,
            copy_to_dataset=False,
            dataset_path=None,
    ):
        """Run all configured MSA/template searches for one protein fasta.

        Args:
            input_fasta_path: fasta whose first sequence is the query.
            output_msas_dir: per-chain output dir; feature sub-dirs are made
                beneath ``output_msas_dir/features``.
            use_precompute: reuse raw alignment files already on disk.
            copy_to_dataset: currently unused placeholder flag.
            dataset_path: optional dataset root; any precomputed md5-named
                feature files found there are copied in, letting the matching
                search stage be skipped.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "msa_features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "uniprot_msa_features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "template_features"), exist_ok=True)

        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, f"bfd_uniclust30_hits.a3m")

        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")

        msa_md5_save_path = os.path.join(output_msas_dir, "features", "msa_features", f"{md5}.pkl.gz")
        uniprot_msa_md5_save_path = os.path.join(output_msas_dir, "features", "uniprot_msa_features", f"{md5}.pkl.gz")
        template_md5_save_path = os.path.join(output_msas_dir, "features", "template_features", f"{md5}.pkl.gz")

        # Seed outputs from a dataset of precomputed features when available.
        if dataset_path is not None:
            dataset_msa_md5_save_path = os.path.join(
                dataset_path, "features", "msa_features", f"{md5}.pkl.gz")
            dataset_uniprot_msa_md5_save_path = os.path.join(
                dataset_path, "features", "uniprot_msa_features", f"{md5}.pkl.gz")
            dataset_template_md5_save_path = os.path.join(
                dataset_path, "features", "template_features", f"{md5}.pkl.gz")

            if os.path.exists(dataset_msa_md5_save_path):
                shutil.copyfile(dataset_msa_md5_save_path, msa_md5_save_path)
            if os.path.exists(dataset_uniprot_msa_md5_save_path):
                shutil.copyfile(dataset_uniprot_msa_md5_save_path, uniprot_msa_md5_save_path)
            if os.path.exists(dataset_template_md5_save_path):
                shutil.copyfile(dataset_template_md5_save_path, template_md5_save_path)

        # Uniref90 search also feeds the template pipeline.
        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(template_md5_save_path):
            if not os.path.exists(uniref90_out_path) or not use_precompute or not os.path.exists(templates_out_path):
                if not os.path.exists(uniref90_out_path):
                    self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)
                if templates_out_path is not None:  # NOTE: join() result, always truthy
                    try:
                        os.makedirs(templates_out_path, exist_ok=True)
                        seq, dec = parsers.parse_fasta(load_txt(input_fasta_path))
                        input_sequence = seq[0]
                        msa_for_templates = parsers.truncate_stockholm_msa(
                            uniref90_out_path, max_sequences=10000
                        )
                        msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                            msa_for_templates
                        )
                        if self.template_searcher.input_format == "sto":
                            pdb_templates_result = self.template_searcher.query(msa_for_templates)
                        elif self.template_searcher.input_format == "a3m":
                            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                            pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                        else:
                            raise ValueError(
                                "Unrecognized template input format: "
                                f"{self.template_searcher.input_format}"
                            )

                        pdb_hits_out_path = os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                        )
                        with open(os.path.join(
                                templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                        ), "w") as f:
                            f.write(pdb_templates_result)

                        pdb_template_hits = self.template_searcher.get_template_hits(
                            output_string=pdb_templates_result, input_sequence=input_sequence
                        )
                        templates_result = self.template_featurizer.get_templates(
                            query_sequence=input_sequence, hits=pdb_template_hits
                        )
                        dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                    except Exception:
                        # Template search is best-effort; MSA searches below still run.
                        logging.exception("An error in template searching")

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(uniprot_msa_md5_save_path):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(msa_md5_save_path):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(msa_md5_save_path):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
|
PhysDock/data/alignment_runner_v2.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os.path
|
| 3 |
+
import shutil
|
| 4 |
+
from functools import partial
|
| 5 |
+
import tqdm
|
| 6 |
+
from typing import Optional, Mapping, Any, Union
|
| 7 |
+
|
| 8 |
+
from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
|
| 9 |
+
from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
|
| 10 |
+
from PhysDock.data.tools.parsers import parse_fasta
|
| 11 |
+
from PhysDock.data.tools.dataset_manager import DatasetManager
|
| 12 |
+
|
| 13 |
+
TemplateSearcher = Union[hhsearch.HHSearch]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AlignmentRunner:
    """Builds jackhmmer/hhblits MSA runners for whichever databases exist.

    Runners are constructed in ``__init__`` as ``functools.partial`` closures
    and stored as instance attributes; any runner whose binary or database
    path is missing stays ``None`` and is skipped in :meth:`run`.
    """

    def __init__(
            self,
            # Databases
            uniref90_database_path: Optional[str] = None,
            uniprot_database_path: Optional[str] = None,
            uniclust30_database_path: Optional[str] = None,
            bfd_database_path: Optional[str] = None,
            mgnify_database_path: Optional[str] = None,

            # Homo Search Tools
            jackhmmer_binary_path: str = "/usr/bin/jackhmmer",
            hhblits_binary_path: str = "/usr/bin/hhblits",

            # Params
            no_cpus: int = 8,

            # Thresholds
            uniref90_seq_limit: int = 100000,
            uniprot_seq_limit: int = 500000,
            mgnify_seq_limit: int = 50000,
            uniref90_max_hits: int = 10000,
            uniprot_max_hits: int = 50000,
            mgnify_max_hits: int = 5000,
    ):
        # Default every runner to None so `run` can safely test them even when
        # the corresponding database/binary was not provided.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        # NOTE(review): bfd_uniref30_hhblits_runner is set but never assigned
        # elsewhere in this class — presumably a leftover from an older setup.
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None

        def _all_exists(*objs, hhblits_mode=False):
            # True iff every path is non-None and exists. In hhblits mode the
            # database argument is a filename prefix, so only its parent
            # directory is required to exist.
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
                fasta_path: str,
                msa_out_path: str,
                msa_runner,
                msa_format: str,
                max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool, checking if output already exists first."""
            # Stockholm outputs can be truncated at query time; other formats
            # take no size cap.
            if (msa_format == "sto" and max_sto_sequences is not None):
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # Output filename extension must match the declared format.
            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits
        if _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Run all configured searches for one fasta into ``output_msas_dir``.

        A search is skipped when its packed feature file (named by the md5 of
        "protein:<sequence>" and looked up two directory levels above
        ``output_msas_dir``) already exists, or when ``use_precompute`` is True
        and the raw hit file is already on disk.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, f"bfd_uniclust30_hits.a3m")

        # md5 of the first sequence identifies this chain's feature files.
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        # Walk up two levels: output_msas_dir is expected to be
        # <features_root>/msas/<md5> — TODO confirm against callers.
        output_feature = os.path.dirname(output_msas_dir)
        output_feature = os.path.dirname(output_feature)

        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(uniref90_out_path) or not use_precompute:
                self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class DataProcessor:
    """Fans alignment work out over a pool of AlignmentRunner workers.

    Holds the database/binary paths once and hands them to a fresh
    AlignmentRunner per task in ``_process_iotuple``.
    """

    def __init__(
            self,
            bfd_database_path,
            uniclust30_database_path,
            uniref90_database_path,
            mgnify_database_path,
            uniprot_database_path,
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,

            n_cpus: int = 8,
            n_workers: int = 1,
    ):
        """Store search-tool and database paths plus pool sizing.

        Args:
            bfd_database_path, uniclust30_database_path, uniref90_database_path,
            mgnify_database_path, uniprot_database_path: database locations
                forwarded verbatim to each AlignmentRunner.
            jackhmmer_binary_path, hhblits_binary_path: search binaries.
            n_cpus: CPUs given to each individual search tool.
            n_workers: number of concurrent alignment tasks.
        """
        self.jackhmmer_binary_path = jackhmmer_binary_path
        self.hhblits_binary_path = hhblits_binary_path

        self.n_cpus = n_cpus
        self.n_workers = n_workers

        self.uniref90_database_path = uniref90_database_path
        self.uniprot_database_path = uniprot_database_path
        self.bfd_database_path = bfd_database_path
        self.uniclust30_database_path = uniclust30_database_path
        self.mgnify_database_path = mgnify_database_path

    def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
        """Build (input_fasta, output_msas_dir) pairs; see class docstring.

        Raises:
            ValueError: if ``input_fasta_path`` is neither a list, a
                directory, nor a file.
        """
        os.makedirs(output_dir, exist_ok=True)
        if isinstance(input_fasta_path, list):
            input_fasta_paths = input_fasta_path
        elif os.path.isdir(input_fasta_path):
            input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
        elif os.path.isfile(input_fasta_path):
            input_fasta_paths = [input_fasta_path]
        else:
            # BUGFIX: the original built `Exception(...)` without raising it
            # and silently continued with an empty work list.
            raise ValueError("Can't parse input fasta path!")
        # First sequence of each fasta determines the md5-based dir name.
        seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
        if convert_md5:
            output_msas_dirs = [os.path.join(output_dir, convert_md5_string(f"{prefix}:{i}")) for i in
                                seqs]
        else:
            # Fall back to the fasta file stem (name before the first dot).
            output_msas_dirs = [os.path.join(output_dir, os.path.split(i)[1].split(".")[0]) for i in input_fasta_paths]
        io_tuples = [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]
        return io_tuples

    def _process_iotuple(self, io_tuple, use_precompute=True):
        """Run one alignment task; failures are logged, never raised."""
        fasta_path, out_dir = io_tuple
        kwargs = {
            "jackhmmer_binary_path": self.jackhmmer_binary_path,
            "hhblits_binary_path": self.hhblits_binary_path,
            "uniref90_database_path": self.uniref90_database_path,
            "bfd_database_path": self.bfd_database_path,
            "uniclust30_database_path": self.uniclust30_database_path,
            "mgnify_database_path": self.mgnify_database_path,
            "uniprot_database_path": self.uniprot_database_path,
        }
        alignment_runner = AlignmentRunner(
            **kwargs,
            no_cpus=self.n_cpus
        )
        try:
            alignment_runner.run(fasta_path, out_dir, use_precompute=use_precompute)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; log with traceback for diagnosability.
            logging.exception(f"{fasta_path}:{out_dir} task failed!")

    def process(self, input_fasta_path, output_dir, convert_md5=True, use_precompute=True):
        """Align every input fasta, distributing tasks over ``n_workers``."""
        prefix = "protein"
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=convert_md5, prefix=prefix)
        run_pool_tasks(partial(self._process_iotuple, use_precompute=use_precompute), io_tuples,
                       num_workers=self.n_workers,
                       return_dict=False)

    def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
        """Copy stem-named alignment dirs to dirs named by sequence md5."""
        import subprocess  # local import: only needed for the copy step

        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
        io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)

        for (_, src), (_, dst) in tqdm.tqdm(zip(io_tuples, io_tuples_md5), total=len(io_tuples)):
            # Argument-list form instead of a shell-interpolated `os.system`
            # string: paths with spaces/metacharacters can't break or inject.
            subprocess.run(["cp", "-r", os.path.abspath(src), os.path.abspath(dst)], check=False)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def run_homo_search(
        bfd_database_path,
        uniclust30_database_path,
        uniref90_database_path,
        mgnify_database_path,
        uniprot_database_path,
        jackhmmer_binary_path,
        hhblits_binary_path,

        input_fasta_path,
        out_dir,

        n_cpus=16,
        n_workers=1,
):
    """Run MSA homology search for one fasta file (or a directory of fastas)
    and convert the raw alignments into MSA feature files.

    Outputs are written under `out_dir`:
      - msas/                  raw alignment-tool output
      - msa_features/          processed MSA features
      - uniprot_msa_features/  processed UniProt MSA features

    Args:
        bfd_database_path, uniclust30_database_path, uniref90_database_path,
        mgnify_database_path, uniprot_database_path: sequence database paths.
        jackhmmer_binary_path, hhblits_binary_path: search-tool executables.
        input_fasta_path: a single fasta file, or a directory of fasta files.
        out_dir: root directory for all outputs.
        n_cpus: CPUs given to each alignment run.
        n_workers: number of parallel worker processes.
    """
    data_processor = DataProcessor(
        bfd_database_path,
        uniclust30_database_path,
        uniref90_database_path,
        mgnify_database_path,
        uniprot_database_path,
        jackhmmer_binary_path=jackhmmer_binary_path,
        hhblits_binary_path=hhblits_binary_path,
        n_cpus=n_cpus,
        n_workers=n_workers,
    )

    output_dir = os.path.join(out_dir, "msas")
    os.makedirs(output_dir, exist_ok=True)
    if os.path.isfile(input_fasta_path):
        files = [input_fasta_path]
    else:
        # Directory input: keep the original reversed listing order
        # (presumably a load-balancing choice — TODO confirm).
        files = [os.path.join(input_fasta_path, name)
                 for name in os.listdir(input_fasta_path)[::-1]]

    data_processor.process(
        input_fasta_path=files,
        output_dir=output_dir,
        convert_md5=True,
    )
    print(f"save msa to {output_dir}")

    msa_dir = os.path.join(out_dir, "msa_features")
    os.makedirs(msa_dir, exist_ok=True)
    # Return values of the converters were previously bound to an unused
    # local (`out`); the bindings are dropped.
    DatasetManager.convert_msas_out_to_msa_features(
        input_fasta_path=input_fasta_path,
        output_dir=output_dir,
        msa_feature_dir=msa_dir,
        convert_md5=True,
        num_workers=2,
    )
    print(f"save msa feature to {msa_dir}")

    msa_dir_uni = os.path.join(out_dir, "uniprot_msa_features")
    os.makedirs(msa_dir_uni, exist_ok=True)
    DatasetManager.convert_msas_out_to_uniprot_msa_features(
        input_fasta_path=input_fasta_path,
        output_dir=output_dir,
        uniprot_msa_feature_dir=msa_dir_uni,
        convert_md5=True,
        num_workers=2,
    )
    print(f"save uni msa feature to {msa_dir_uni}")
|
PhysDock/data/constants/PDBData.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2000 Andrew Dalke. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This file is part of the Biopython distribution and governed by your
|
| 4 |
+
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
|
| 5 |
+
# Please see the LICENSE file that should have been included as part of this
|
| 6 |
+
# package.
|
| 7 |
+
"""Information about the IUPAC alphabets."""
|
| 8 |
+
|
| 9 |
+
# One-letter codes of the 20 standard amino acids.
protein_letters = "ACDEFGHIKLMNPQRSTVWY"
# Standard letters plus the ambiguity / rare codes B, X, Z, J, U, O
# (explained below).
extended_protein_letters = "ACDEFGHIKLMNPQRSTVWYBXZJUO"

# B = "Asx"; aspartic acid or asparagine (D or N)
# X = "Xxx"; unknown or 'other' amino acid
# Z = "Glx"; glutamic acid or glutamine (E or Q)
# http://www.chem.qmul.ac.uk/iupac/AminoAcid/A2021.html#AA212
#
# J = "Xle"; leucine or isoleucine (L or I, used in NMR)
# Mentioned in http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
# Also the International Nucleotide Sequence Database Collaboration (INSDC)
# (i.e. GenBank, EMBL, DDBJ) adopted this in 2006
# http://www.ddbj.nig.ac.jp/insdc/icm2006-e.html
#
# Xle (J); Leucine or Isoleucine
# The residue abbreviations, Xle (the three-letter abbreviation) and J
# (the one-letter abbreviation) are reserved for the case that cannot
# experimentally distinguish leucine from isoleucine.
#
# U = "Sec"; selenocysteine
# http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
#
# O = "Pyl"; pyrrolysine
# http://www.chem.qmul.ac.uk/iubmb/newsletter/2009.html#item35

# One-letter -> three-letter codes for the standard amino acids.
protein_letters_1to3 = {
    "A": "Ala",
    "C": "Cys",
    "D": "Asp",
    "E": "Glu",
    "F": "Phe",
    "G": "Gly",
    "H": "His",
    "I": "Ile",
    "K": "Lys",
    "L": "Leu",
    "M": "Met",
    "N": "Asn",
    "P": "Pro",
    "Q": "Gln",
    "R": "Arg",
    "S": "Ser",
    "T": "Thr",
    "V": "Val",
    "W": "Trp",
    "Y": "Tyr",
}
# Rebind with both sides upper-cased (e.g. "Ala" -> "ALA"), matching
# the upper-case residue names used elsewhere in this module.
protein_letters_1to3 = {k.upper(): v.upper() for k, v in protein_letters_1to3.items()}
# Inverse mapping: upper-case three-letter -> one-letter.
protein_letters_3to1 = {v: k for k, v in protein_letters_1to3.items()}
|
| 58 |
+
|
| 59 |
+
# Three-letter -> one-letter mapping covering the standard residues plus a
# large set of modified / non-standard residue codes, each mapped to its
# closest standard parent amino acid.
# NOTE(review): a few keys appear space-padded (e.g. "MA "); confirm exact
# padding against the original data source.
protein_letters_3to1_extended = {
    "A5N": "N", "A8E": "V", "A9D": "S", "AA3": "A", "AA4": "A", "AAR": "R",
    "ABA": "A", "ACL": "R", "AEA": "C", "AEI": "D", "AFA": "N", "AGM": "R",
    "AGQ": "Y", "AGT": "C", "AHB": "N", "AHL": "R", "AHO": "A", "AHP": "A",
    "AIB": "A", "AKL": "D", "AKZ": "D", "ALA": "A", "ALC": "A", "ALM": "A",
    "ALN": "A", "ALO": "T", "ALS": "A", "ALT": "A", "ALV": "A", "ALY": "K",
    "AME": "M", "AN6": "L", "AN8": "A", "API": "K", "APK": "K", "AR2": "R",
    "AR4": "E", "AR7": "R", "ARG": "R", "ARM": "R", "ARO": "R", "AS7": "N",
    "ASA": "D", "ASB": "D", "ASI": "D", "ASK": "D", "ASL": "D", "ASN": "N",
    "ASP": "D", "ASQ": "D", "AYA": "A", "AZH": "A", "AZK": "K", "AZS": "S",
    "AZY": "Y", "AVJ": "H", "A30": "Y", "A3U": "F", "ECC": "Q", "ECX": "C",
    "EFC": "C", "EHP": "F", "ELY": "K", "EME": "E", "EPM": "M", "EPQ": "Q",
    "ESB": "Y", "ESC": "M", "EXY": "L", "EXA": "K", "E0Y": "P", "E9V": "H",
    "E9M": "W", "EJA": "C", "EUP": "T", "EZY": "G", "E9C": "Y", "EW6": "S",
    "EXL": "W", "I2M": "I", "I4G": "G", "I58": "K", "IAM": "A", "IAR": "R",
    "ICY": "C", "IEL": "K", "IGL": "G", "IIL": "I", "ILE": "I", "ILG": "E",
    "ILM": "I", "ILX": "I", "ILY": "K", "IML": "I", "IOR": "R", "IPG": "G",
    "IT1": "K", "IYR": "Y", "IZO": "M", "IC0": "G", "M0H": "C", "M2L": "K",
    "M2S": "M", "M30": "G", "M3L": "K", "M3R": "K", "MA ": "A", "MAA": "A",
    "MAI": "R", "MBQ": "Y", "MC1": "S", "MCL": "K", "MCS": "C", "MD3": "C",
    "MD5": "C", "MD6": "G", "MDF": "Y", "ME0": "M", "MEA": "F", "MEG": "E",
    "MEN": "N", "MEQ": "Q", "MET": "M", "MEU": "G", "MFN": "E", "MGG": "R",
    "MGN": "Q", "MGY": "G", "MH1": "H", "MH6": "S", "MHL": "L", "MHO": "M",
    "MHS": "H", "MHU": "F", "MIR": "S", "MIS": "S", "MK8": "L", "ML3": "K",
    "MLE": "L", "MLL": "L", "MLY": "K", "MLZ": "K", "MME": "M", "MMO": "R",
    "MNL": "L", "MNV": "V", "MP8": "P", "MPQ": "G", "MSA": "G", "MSE": "M",
    "MSL": "M", "MSO": "M", "MT2": "M", "MTY": "Y", "MVA": "V", "MYK": "K",
    "MYN": "R", "QCS": "C", "QIL": "I", "QMM": "Q", "QPA": "C", "QPH": "F",
    "Q3P": "K", "QVA": "C", "QX7": "A", "Q2E": "W", "Q75": "M", "Q78": "F",
    "QM8": "L", "QMB": "A", "QNQ": "C", "QNT": "C", "QNW": "C", "QO2": "C",
    "QO5": "C", "QO8": "C", "QQ8": "Q", "U2X": "Y", "U3X": "F", "UF0": "S",
    "UGY": "G", "UM1": "A", "UM2": "A", "UMA": "A", "UQK": "A", "UX8": "W",
    "UXQ": "F", "YCM": "C", "YOF": "Y", "YPR": "P", "YPZ": "Y", "YTH": "T",
    "Y1V": "L", "Y57": "K", "YHA": "K", "200": "F", "23F": "F", "23P": "A",
    "26B": "T", "28X": "T", "2AG": "A", "2CO": "C", "2FM": "M", "2GX": "F",
    "2HF": "H", "2JG": "S", "2KK": "K", "2KP": "K", "2LT": "Y", "2LU": "L",
    "2ML": "L", "2MR": "R", "2MT": "P", "2OR": "R", "2P0": "P", "2QZ": "T",
    "2R3": "Y", "2RA": "A", "2RX": "S", "2SO": "H", "2TY": "Y", "2VA": "V",
    "2XA": "C", "2ZC": "S", "6CL": "K", "6CW": "W", "6GL": "A", "6HN": "K",
    "60F": "C", "66D": "I", "6CV": "A", "6M6": "C", "6V1": "C", "6WK": "C",
    "6Y9": "P", "6DN": "K", "DA2": "R", "DAB": "A", "DAH": "F", "DBS": "S",
    "DBU": "T", "DBY": "Y", "DBZ": "A", "DC2": "C", "DDE": "H", "DDZ": "A",
    "DI7": "Y", "DHA": "S", "DHN": "V", "DIR": "R", "DLS": "K", "DM0": "K",
    "DMH": "N", "DMK": "D", "DNL": "K", "DNP": "A", "DNS": "K", "DNW": "A",
    "DOH": "D", "DON": "L", "DP1": "R", "DPL": "P", "DPP": "A", "DPQ": "Y",
    "DYS": "C", "D2T": "D", "DYA": "D", "DJD": "F", "DYJ": "P", "DV9": "E",
    "H14": "F", "H1D": "M", "H5M": "P", "HAC": "A", "HAR": "R", "HBN": "H",
    "HCM": "C", "HGY": "G", "HHI": "H", "HIA": "H", "HIC": "H", "HIP": "H",
    "HIQ": "H", "HIS": "H", "HL2": "L", "HLU": "L", "HMR": "R", "HNC": "C",
    "HOX": "F", "HPC": "F", "HPE": "F", "HPH": "F", "HPQ": "F", "HQA": "A",
    "HR7": "R", "HRG": "R", "HRP": "W", "HS8": "H", "HS9": "H", "HSE": "S",
    "HSK": "H", "HSL": "S", "HSO": "H", "HT7": "W", "HTI": "C", "HTR": "W",
    "HV5": "A", "HVA": "V", "HY3": "P", "HYI": "M", "HYP": "P", "HZP": "P",
    "HIX": "A", "HSV": "H", "HLY": "K", "HOO": "H", "H7V": "A", "L5P": "K",
    "LRK": "K", "L3O": "L", "LA2": "K", "LAA": "D", "LAL": "A", "LBY": "K",
    "LCK": "K", "LCX": "K", "LDH": "K", "LE1": "V", "LED": "L", "LEF": "L",
    "LEH": "L", "LEM": "L", "LEN": "L", "LET": "K", "LEU": "L", "LEX": "L",
    "LGY": "K", "LLO": "K", "LLP": "K", "LLY": "K", "LLZ": "K", "LME": "E",
    "LMF": "K", "LMQ": "Q", "LNE": "L", "LNM": "L", "LP6": "K", "LPD": "P",
    "LPG": "G", "LPS": "S", "LSO": "K", "LTR": "W", "LVG": "G", "LVN": "V",
    "LWY": "P", "LYF": "K", "LYK": "K", "LYM": "K", "LYN": "K", "LYO": "K",
    "LYP": "K", "LYR": "K", "LYS": "K", "LYU": "K", "LYX": "K", "LYZ": "K",
    "LAY": "L", "LWI": "F", "LBZ": "K", "P1L": "C", "P2Q": "Y", "P2Y": "P",
    "P3Q": "Y", "PAQ": "Y", "PAS": "D", "PAT": "W", "PBB": "C", "PBF": "F",
    "PCA": "Q", "PCC": "P", "PCS": "F", "PE1": "K", "PEC": "C", "PF5": "F",
    "PFF": "F", "PG1": "S", "PGY": "G", "PHA": "F", "PHD": "D", "PHE": "F",
    "PHI": "F", "PHL": "F", "PHM": "F", "PKR": "P", "PLJ": "P", "PM3": "F",
    "POM": "P", "PPN": "F", "PR3": "C", "PR4": "P", "PR7": "P", "PR9": "P",
    "PRJ": "P", "PRK": "K", "PRO": "P", "PRS": "P", "PRV": "G", "PSA": "F",
    "PSH": "H", "PTH": "Y", "PTM": "Y", "PTR": "Y", "PVH": "H", "PXU": "P",
    "PYA": "A", "PYH": "K", "PYX": "C", "PH6": "P", "P9S": "C", "P5U": "S",
    "POK": "R", "T0I": "Y", "T11": "F", "TAV": "D", "TBG": "V", "TBM": "T",
    "TCQ": "Y", "TCR": "W", "TEF": "F", "TFQ": "F", "TH5": "T", "TH6": "T",
    "THC": "T", "THR": "T", "THZ": "R", "TIH": "A", "TIS": "S", "TLY": "K",
    "TMB": "T", "TMD": "T", "TNB": "C", "TNR": "S", "TNY": "T", "TOQ": "W",
    "TOX": "W", "TPJ": "P", "TPK": "P", "TPL": "W", "TPO": "T", "TPQ": "Y",
    "TQI": "W", "TQQ": "W", "TQZ": "C", "TRF": "W", "TRG": "K", "TRN": "W",
    "TRO": "W", "TRP": "W", "TRQ": "W", "TRW": "W", "TRX": "W", "TRY": "W",
    "TS9": "I", "TSY": "C", "TTQ": "W", "TTS": "Y", "TXY": "Y", "TY1": "Y",
    "TY2": "Y", "TY3": "Y", "TY5": "Y", "TY8": "Y", "TY9": "Y", "TYB": "Y",
    "TYC": "Y", "TYE": "Y", "TYI": "Y", "TYJ": "Y", "TYN": "Y", "TYO": "Y",
    "TYQ": "Y", "TYR": "Y", "TYS": "Y", "TYT": "Y", "TYW": "Y", "TYY": "Y",
    "T8L": "T", "T9E": "T", "TNQ": "W", "TSQ": "F", "TGH": "W", "X2W": "E",
    "XCN": "C", "XPR": "P", "XSN": "N", "XW1": "A", "XX1": "K", "XYC": "A",
    "XA6": "F", "11Q": "P", "11W": "E", "12L": "P", "12X": "P", "12Y": "P",
    "143": "C", "1AC": "A", "1L1": "A", "1OP": "Y", "1PA": "F", "1PI": "A",
    "1TQ": "W", "1TY": "Y", "1X6": "S", "56A": "H", "5AB": "A", "5CS": "C",
    "5CW": "W", "5HP": "E", "5OH": "A", "5PG": "G", "51T": "Y", "54C": "W",
    "5CR": "F", "5CT": "K", "5FQ": "A", "5GM": "I", "5JP": "S", "5T3": "K",
    "5MW": "K", "5OW": "K", "5R5": "S", "5VV": "N", "5XU": "A", "55I": "F",
    "999": "D", "9DN": "N", "9NE": "E", "9NF": "F", "9NR": "R", "9NV": "V",
    "9E7": "K", "9KP": "K", "9WV": "A", "9TR": "K", "9TU": "K", "9TX": "K",
    "9U0": "K", "9IJ": "F", "B1F": "F", "B27": "T", "B2A": "A", "B2F": "F",
    "B2I": "I", "B2V": "V", "B3A": "A", "B3D": "D", "B3E": "E", "B3K": "K",
    "B3U": "H", "B3X": "N", "B3Y": "Y", "BB6": "C", "BB7": "C", "BB8": "F",
    "BB9": "C", "BBC": "C", "BCS": "C", "BCX": "C", "BFD": "D", "BG1": "S",
    "BH2": "D", "BHD": "D", "BIF": "F", "BIU": "I", "BL2": "L", "BLE": "L",
    "BLY": "K", "BMT": "T", "BNN": "F", "BOR": "R", "BP5": "A", "BPE": "C",
    "BSE": "S", "BTA": "L", "BTC": "C", "BTK": "K", "BTR": "W", "BUC": "C",
    "BUG": "V", "BYR": "Y", "BWV": "R", "BWB": "S", "BXT": "S", "F2F": "F",
    "F2Y": "Y", "FAK": "K", "FB5": "A", "FB6": "A", "FC0": "F", "FCL": "F",
    "FDL": "K", "FFM": "C", "FGL": "G", "FGP": "S", "FH7": "K", "FHL": "K",
    "FHO": "K", "FIO": "R", "FLA": "A", "FLE": "L", "FLT": "Y", "FME": "M",
    "FOE": "C", "FP9": "P", "FPK": "P", "FT6": "W", "FTR": "W", "FTY": "Y",
    "FVA": "V", "FZN": "K", "FY3": "Y", "F7W": "W", "FY2": "Y", "FQA": "K",
    "F7Q": "Y", "FF9": "K", "FL6": "D", "JJJ": "C", "JJK": "C", "JJL": "C",
    "JLP": "K", "J3D": "C", "J9Y": "R", "J8W": "S", "JKH": "P", "N10": "S",
    "N7P": "P", "NA8": "A", "NAL": "A", "NAM": "A", "NBQ": "Y", "NC1": "S",
    "NCB": "A", "NEM": "H", "NEP": "H", "NFA": "F", "NIY": "Y", "NLB": "L",
    "NLE": "L", "NLN": "L", "NLO": "L", "NLP": "L", "NLQ": "Q", "NLY": "G",
    "NMC": "G", "NMM": "R", "NNH": "R", "NOT": "L", "NPH": "C", "NPI": "A",
    "NTR": "Y", "NTY": "Y", "NVA": "V", "NWD": "A", "NYB": "C", "NYS": "C",
    "NZH": "H", "N80": "P", "NZC": "T", "NLW": "L", "N0A": "F", "N9P": "A",
    "N65": "K", "R1A": "C", "R4K": "W", "RE0": "W", "RE3": "W", "RGL": "R",
    "RGP": "E", "RT0": "P", "RVX": "S", "RZ4": "S", "RPI": "R", "RVJ": "A",
    "VAD": "V", "VAF": "V", "VAH": "V", "VAI": "V", "VAL": "V", "VB1": "K",
    "VH0": "P", "VR0": "R", "V44": "C", "V61": "F", "VPV": "K", "V5N": "H",
    "V7T": "K", "Z01": "A", "Z3E": "T", "Z70": "H", "ZBZ": "C", "ZCL": "F",
    "ZU0": "T", "ZYJ": "P", "ZYK": "P", "ZZD": "C", "ZZJ": "A", "ZIQ": "W",
    "ZPO": "P", "ZDJ": "Y", "ZT1": "K", "30V": "C", "31Q": "C", "33S": "F",
    "33W": "A", "34E": "V", "3AH": "H", "3BY": "P", "3CF": "F", "3CT": "Y",
    "3GA": "A", "3GL": "E", "3MD": "D", "3MY": "Y", "3NF": "Y", "3O3": "E",
    "3PX": "P", "3QN": "K", "3TT": "P", "3XH": "G", "3YM": "Y", "3WS": "A",
    "3WX": "P", "3X9": "C", "3ZH": "H", "7JA": "I", "73C": "S", "73N": "R",
    "73O": "Y", "73P": "K", "74P": "K", "7N8": "F", "7O5": "A", "7XC": "F",
    "7ID": "D", "7OZ": "A", "C1S": "C", "C1T": "C", "C1X": "K", "C22": "A",
    "C3Y": "C", "C4R": "C", "C5C": "C", "C6C": "C", "CAF": "C", "CAS": "C",
    "CAY": "C", "CCS": "C", "CEA": "C", "CGA": "E", "CGU": "E", "CGV": "C",
    "CHP": "G", "CIR": "R", "CLE": "L", "CLG": "K", "CLH": "K", "CME": "C",
    "CMH": "C", "CML": "C", "CMT": "C", "CR5": "G", "CS0": "C", "CS1": "C",
    "CS3": "C", "CS4": "C", "CSA": "C", "CSB": "C", "CSD": "C", "CSE": "C",
    "CSJ": "C", "CSO": "C", "CSP": "C", "CSR": "C", "CSS": "C", "CSU": "C",
    "CSW": "C", "CSX": "C", "CSZ": "C", "CTE": "W", "CTH": "T", "CWD": "A",
    "CWR": "S", "CXM": "M", "CY0": "C", "CY1": "C", "CY3": "C", "CY4": "C",
    "CYA": "C", "CYD": "C", "CYF": "C", "CYG": "C", "CYJ": "K", "CYM": "C",
    "CYQ": "C", "CYR": "C", "CYS": "C", "CYW": "C", "CZ2": "C", "CZZ": "C",
    "CG6": "C", "C1J": "R", "C4G": "R", "C67": "R", "C6D": "R", "CE7": "N",
    "CZS": "A", "G01": "E", "G8M": "E", "GAU": "E", "GEE": "G", "GFT": "S",
    "GHC": "E", "GHG": "Q", "GHW": "E", "GL3": "G", "GLH": "Q", "GLJ": "E",
    "GLK": "E", "GLN": "Q", "GLQ": "E", "GLU": "E", "GLY": "G", "GLZ": "G",
    "GMA": "E", "GME": "E", "GNC": "Q", "GPL": "K", "GSC": "G", "GSU": "E",
    "GT9": "C", "GVL": "S", "G3M": "R", "G5G": "L", "G1X": "Y", "G8X": "P",
    "K1R": "C", "KBE": "K", "KCX": "K", "KFP": "K", "KGC": "K", "KNB": "A",
    "KOR": "M", "KPI": "K", "KPY": "K", "KST": "K", "KYN": "W", "KYQ": "K",
    "KCR": "K", "KPF": "K", "K5L": "S", "KEO": "K", "KHB": "K", "KKD": "D",
    "K5H": "C", "K7K": "S", "OAR": "R", "OAS": "S", "OBS": "K", "OCS": "C",
    "OCY": "C", "OHI": "H", "OHS": "D", "OLD": "H", "OLT": "T", "OLZ": "S",
    "OMH": "S", "OMT": "M", "OMX": "Y", "OMY": "Y", "ONH": "A", "ORN": "A",
    "ORQ": "R", "OSE": "S", "OTH": "T", "OXX": "D", "OYL": "H", "O7A": "T",
    "O7D": "W", "O7G": "V", "O2E": "S", "O6H": "W", "OZW": "F", "S12": "S",
    "S1H": "S", "S2C": "C", "S2P": "A", "SAC": "S", "SAH": "C", "SAR": "G",
    "SBG": "S", "SBL": "S", "SCH": "C", "SCS": "C", "SCY": "C", "SD4": "N",
    "SDB": "S", "SDP": "S", "SEB": "S", "SEE": "S", "SEG": "A", "SEL": "S",
    "SEM": "S", "SEN": "S", "SEP": "S", "SER": "S", "SET": "S", "SGB": "S",
    "SHC": "C", "SHP": "G", "SHR": "K", "SIB": "C", "SLL": "K", "SLZ": "K",
    "SMC": "C", "SME": "M", "SMF": "F", "SNC": "C", "SNN": "N", "SOY": "S",
    "SRZ": "S", "STY": "Y", "SUN": "S", "SVA": "S", "SVV": "S", "SVW": "S",
    "SVX": "S", "SVY": "S", "SVZ": "S", "SXE": "S", "SKH": "K", "SNM": "S",
    "SNK": "H", "SWW": "S", "WFP": "F", "WLU": "L", "WPA": "F", "WRP": "W",
    "WVL": "V", "02K": "A", "02L": "N", "02O": "A", "02Y": "A", "033": "V",
    "037": "P", "03Y": "C", "04U": "P", "04V": "P", "05N": "P", "07O": "C",
    "0A0": "D", "0A1": "Y", "0A2": "K", "0A8": "C", "0A9": "F", "0AA": "V",
    "0AB": "V", "0AC": "G", "0AF": "W", "0AG": "L", "0AH": "S", "0AK": "D",
    "0AR": "R", "0BN": "F", "0CS": "A", "0E5": "T", "0EA": "Y", "0FL": "A",
    "0LF": "P", "0NC": "A", "0PR": "Y", "0QL": "C", "0TD": "D", "0UO": "W",
    "0WZ": "Y", "0X9": "R", "0Y8": "P", "4AF": "F", "4AR": "R", "4AW": "W",
    "4BF": "F", "4CF": "F", "4CY": "M", "4DP": "W", "4FB": "P", "4FW": "W",
    "4HL": "Y", "4HT": "W", "4IN": "W", "4MM": "M", "4PH": "F", "4U7": "A",
    "41H": "F", "41Q": "N", "42Y": "S", "432": "S", "45F": "P", "4AK": "K",
    "4D4": "R", "4GJ": "C", "4KY": "P", "4L0": "P", "4LZ": "Y", "4N7": "P",
    "4N8": "P", "4N9": "P", "4OG": "W", "4OU": "F", "4OV": "S", "4OZ": "S",
    "4PQ": "W", "4SJ": "F", "4WQ": "A", "4HH": "S", "4HJ": "S", "4J4": "C",
    "4J5": "R", "4II": "F", "4VI": "R", "823": "N", "8SP": "S", "8AY": "A",
}
|
| 233 |
+
|
| 234 |
+
# Nucleic Acids
# Three-letter (chem-comp style) -> one-letter codes for the standard
# ribo- and deoxyribo-nucleotides.
# NOTE(review): keys appear to be space-padded residue names (e.g. "DA ");
# confirm the exact padding against the original data source.
nucleic_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}

# RNA subset of the mapping above.
rna_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
}

# DNA subset of the mapping above.
dna_letters_3to1 = {
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}
|
| 247 |
+
|
| 248 |
+
# fmt: off
# Extended three-letter -> one-letter mapping for nucleotides, covering the
# standard residues plus many modified nucleotide codes, each mapped to its
# closest standard parent base.
# NOTE(review): several keys appear space-padded (e.g. "DA ", "GS "); confirm
# exact padding against the original data source.
nucleic_letters_3to1_extended = {
    "A ": "A", "A23": "A", "A2L": "A", "A2M": "A", "A34": "A", "A35": "A",
    "A38": "A", "A39": "A", "A3A": "A", "A3P": "A", "A40": "A", "A43": "A",
    "A44": "A", "A47": "A", "A5L": "A", "A5M": "C", "A5O": "A", "A6A": "A",
    "A6C": "C", "A6G": "G", "A6U": "U", "A7E": "A", "A9Z": "A", "ABR": "A",
    "ABS": "A", "AD2": "A", "ADI": "A", "ADP": "A", "AET": "A", "AF2": "A",
    "AFG": "G", "AMD": "A", "AMO": "A", "AP7": "A", "AS ": "A", "ATD": "T",
    "ATL": "T", "ATM": "T", "AVC": "A", "AI5": "C", "E ": "A", "E1X": "A",
    "EDA": "A", "EFG": "G", "EHG": "G", "EIT": "T", "EXC": "C", "E3C": "C",
    "E6G": "G", "E7G": "G", "EQ4": "G", "EAN": "T", "I5C": "C", "IC ": "C",
    "IG ": "G", "IGU": "G", "IMC": "C", "IMP": "G", "IU ": "U", "I4U": "U",
    "IOO": "G", "M1G": "G", "M2G": "G", "M4C": "C", "M5M": "C", "MA6": "A",
    "MA7": "A", "MAD": "A", "MCY": "C", "ME6": "C", "MEP": "U", "MG1": "G",
    "MGQ": "A", "MGT": "G", "MGV": "G", "MIA": "A", "MMT": "T", "MNU": "U",
    "MRG": "G", "MTR": "T", "MTU": "A", "MFO": "G", "M7A": "A", "MHG": "G",
    "MMX": "C", "QUO": "G", "QCK": "T", "QSQ": "A", "U ": "U", "U25": "U",
    "U2L": "U", "U2P": "U", "U31": "U", "U34": "U", "U36": "U", "U37": "U",
    "U8U": "U", "UAR": "U", "UBB": "U", "UBD": "U", "UD5": "U", "UPV": "U",
    "UR3": "U", "URD": "U", "US3": "T", "US5": "U", "UZR": "U", "UMO": "U",
    "U23": "U", "U48": "C", "U7B": "C", "Y ": "A", "YCO": "C", "YG ": "G",
    "YYG": "G", "23G": "G", "26A": "A", "2AR": "A", "2AT": "T", "2AU": "U",
    "2BT": "T", "2BU": "A", "2DA": "A", "2DT": "T", "2EG": "G", "2GT": "T",
    "2JV": "G", "2MA": "A", "2MG": "G", "2MU": "U", "2NT": "T", "2OM": "U",
    "2OT": "T", "2PR": "G", "2SG": "G", "2ST": "T", "63G": "G", "63H": "G",
    "64T": "T", "68Z": "G", "6CT": "T", "6HA": "A", "6HB": "A", "6HC": "C",
    "6HG": "G", "6HT": "T", "6IA": "A", "6MA": "A", "6MC": "A", "6MP": "A",
    "6MT": "A", "6MZ": "A", "6OG": "G", "6PO": "G", "6FK": "G", "6NW": "A",
    "6OO": "C", "D00": "C", "D3T": "T", "D4M": "T", "DA ": "A", "DC ": "C",
    "DCG": "G", "DCT": "C", "DDG": "G", "DFC": "C", "DFG": "G", "DG ": "G",
    "DG8": "G", "DGI": "G", "DGP": "G", "DHU": "U", "DNR": "C", "DOC": "C",
    "DPB": "T", "DRT": "T", "DT ": "T", "DZM": "A", "D4B": "C", "H2U": "U",
    "HN0": "G", "HN1": "G", "LC ": "C", "LCA": "A", "LCG": "G", "LG ": "G",
    "LGP": "G", "LHU": "U", "LSH": "T", "LST": "T", "LDG": "G", "L3X": "A",
    "LHH": "C", "LV2": "C", "L1J": "G", "P ": "G", "P2T": "T", "P5P": "A",
    "PG7": "G", "PGN": "G", "PGP": "G", "PMT": "C", "PPU": "A", "PPW": "G",
    "PR5": "A", "PRN": "A", "PST": "T", "PSU": "U", "PU ": "A", "PVX": "C",
    "PYO": "U", "PZG": "G", "P4U": "U", "P7G": "G", "T ": "T", "T2S": "T",
    "T31": "U", "T32": "T", "T36": "T", "T37": "T", "T38": "T", "T39": "T",
    "T3P": "T", "T41": "T", "T48": "T", "T49": "T", "T4S": "T", "T5S": "T",
    "T64": "T", "T6A": "A", "TA3": "T", "TAF": "T", "TBN": "A", "TC1": "C",
    "TCP": "T", "TCY": "A", "TDY": "T", "TED": "T", "TFE": "T", "TFF": "T",
    "TFO": "A", "TFT": "T", "TGP": "G", "TCJ": "C", "TLC": "T", "TP1": "T",
    "TPC": "C", "TPG": "G", "TSP": "T", "TTD": "T", "TTM": "T", "TXD": "A",
    "TXP": "A", "TC ": "C", "TG ": "G", "T0N": "G", "T0Q": "G", "X ": "G",
    "XAD": "A", "XAL": "A", "XCL": "C", "XCR": "C", "XCT": "C", "XCY": "C",
    "XGL": "G", "XGR": "G", "XGU": "G", "XPB": "G", "XTF": "T", "XTH": "T",
    "XTL": "T", "XTR": "T", "XTS": "G", "XUA": "A", "XUG": "G", "102": "G",
    "10C": "C", "125": "U", "126": "U", "127": "U", "12A": "A", "16B": "C",
    "18M": "G", "1AP": "A", "1CC": "C", "1FC": "C", "1MA": "A", "1MG": "G",
    "1RN": "U", "1SC": "C", "5AA": "A", "5AT": "T", "5BU": "U", "5CG": "G",
    "5CM": "C", "5FA": "A", "5FC": "C", "5FU": "U", "5HC": "C", "5HM": "C",
    "5HT": "T", "5IC": "C", "5IT": "T", "5MC": "C", "5MU": "U", "5NC": "C",
    "5PC": "C", "5PY": "T", "9QV": "U", "94O": "T", "9SI": "A", "9SY": "A",
    "B7C": "C", "BGM": "G", "BOE": "T", "B8H": "U", "B8K": "G", "B8Q": "C",
    "B8T": "C", "B8W": "G", "B9B": "G", "B9H": "C", "BGH": "G", "F3H": "T",
    "F3N": "A", "F4H": "T", "FA2": "A", "FDG": "G", "FHU": "U", "FMG": "G",
    "FNU": "U", "FOX": "G", "F2T": "U", "F74": "G", "F4Q": "G", "F7H": "C",
    "F7K": "G", "JDT": "T", "JMH": "C", "J0X": "C", "N5M": "C", "N6G": "G",
    "N79": "A", "NCU": "C", "NMS": "T", "NMT": "T", "NTT": "T", "N7X": "C",
    "R ": "A", "RBD": "A", "RDG": "G", "RIA": "A", "RMP": "A", "RPC": "C",
    "RSP": "C", "RSQ": "C", "RT ": "T", "RUS": "U", "RFJ": "G", "V3L": "A",
    "VC7": "G", "Z ": "C", "ZAD": "A", "ZBC": "C", "ZBU": "U", "ZCY": "C",
    "ZGU": "G", "31H": "A", "31M": "A", "3AU": "U", "3DA": "A", "3ME": "U",
    "3MU": "U", "3TD": "U", "70U": "U", "7AT": "A", "7DA": "A", "7GU": "G",
    "7MG": "G", "7BG": "G", "73W": "C", "75B": "U", "7OK": "C", "7S3": "G",
    "7SN": "G", "C ": "C", "C25": "C", "C2L": "C", "C2S": "C", "C31": "C",
    "C32": "C", "C34": "C", "C36": "C", "C37": "C", "C38": "C", "C42": "C",
    "C43": "C", "C45": "C", "C46": "C", "C49": "C", "C4S": "C", "C5L": "C",
    "C6G": "G", "CAR": "C", "CB2": "C", "CBR": "C", "CBV": "C", "CCC": "C",
    "CDW": "C", "CFL": "C", "CFZ": "C", "CG1": "G", "CH ": "C", "CMR": "C",
    "CNU": "U", "CP1": "C", "CSF": "C", "CSL": "C", "CTG": "T", "CX2": "C",
    "C7S": "C", "C7R": "C", "G ": "G", "G1G": "G", "G25": "G", "G2L": "G",
    "G2S": "G", "G31": "G", "G32": "G", "G33": "G", "G36": "G", "G38": "G",
    "G42": "G", "G46": "G", "G47": "G", "G48": "G", "G49": "G", "G7M": "G",
    "GAO": "G", "GCK": "C", "GDO": "G", "GDP": "G", "GDR": "G", "GF2": "G",
    "GFL": "G", "GH3": "G", "GMS": "G", "GN7": "G", "GNG": "G", "GOM": "G",
    "GRB": "G", "GS ": "G", "GSR": "G", "GSS": "G", "GTP": "G", "GX1": "G",
    "KAG": "G", "KAK": "G", "O2G": "G", "OGX": "G", "OMC": "C", "OMG": "G",
    "OMU": "U", "ONE": "U", "O2Z": "A", "OKN": "C", "OKQ": "C", "S2M": "T",
    "S4A": "A", "S4C": "C", "S4G": "G", "S4U": "U", "S6G": "G", "SC ": "C",
    "SDE": "A", "SDG": "G", "SDH": "G", "SMP": "A", "SMT": "T", "SPT": "T",
    "SRA": "A", "SSU": "U", "SUR": "U", "00A": "A", "0AD": "G", "0AM": "A",
    "0AP": "C", "0AV": "A", "0R8": "C", "0SP": "A", "0UH": "G", "47C": "C",
    "4OC": "C", "4PC": "C", "4PD": "C", "4PE": "C", "4SC": "C", "4SU": "U",
    "45A": "A", "4U3": "C", "8AG": "G", "8AN": "A", "8BA": "A", "8FG": "G",
    "8MG": "G", "8OG": "G", "8PY": "G", "8AA": "G", "85Y": "U", "8OS": "G",
    "UNK": "X", # DEBUG
}
|
| 337 |
+
|
| 338 |
+
# Aliases for the standard (20-residue) protein mappings defined above.
standard_protein_letters_3to1 = protein_letters_3to1
standard_protein_letters_1to3 = protein_letters_1to3
# Modified / non-standard residues only: extended entries minus the standard ones.
nonstandard_protein_letters_3to1 = {k: v for k, v in protein_letters_3to1_extended.items() if
                                    k not in standard_protein_letters_3to1}

# Same split for nucleic acids.
standard_nucleic_letters_3to1 = nucleic_letters_3to1
# Inversion note: one-letter keys shared by RNA and DNA (A, C, G) resolve to
# the DNA three-letter name, because the DNA entries come later in
# `nucleic_letters_3to1` and overwrite earlier values on inversion.
standard_nucleic_letters_1to3 = {v: k for k, v in standard_nucleic_letters_3to1.items()}
nonstandard_nucleic_letters_3to1 = {k: v for k, v in nucleic_letters_3to1_extended.items() if
                                    k not in standard_nucleic_letters_3to1}

# Combined protein + nucleic 3->1 lookup; nucleic entries win on key collisions
# because they are merged last.
letters_3to1_extended = {**protein_letters_3to1_extended, **nucleic_letters_3to1_extended}
|
PhysDock/data/constants/__init__.py
ADDED
|
File without changes
|
PhysDock/data/constants/periodic_table.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Element symbols ordered by atomic number (Z = index + 1), lower-case form.
periodic_table = [
    "h", "he",
    "li", "be", "b", "c", "n", "o", "f", "ne",
    "na", "mg", "al", "si", "p", "s", "cl", "ar",
    "k", "ca", "sc", "ti", "v", "cr", "mn", "fe", "co", "ni", "cu", "zn", "ga", "ge", "as", "se", "br", "kr",
    "rb", "sr", "y", "zr", "nb", "mo", "tc", "ru", "rh", "pd", "ag", "cd", "in", "sn", "sb", "te", "i", "xe",
    "cs", "ba",
    "la", "ce", "pr", "nd", "pm", "sm", "eu", "gd", "tb", "dy", "ho", "er", "tm", "yb", "lu",
    "hf", "ta", "w", "re", "os", "ir", "pt", "au", "hg", "tl", "pb", "bi", "po", "at", "rn",
    "fr", "ra",
    "ac", "th", "pa", "u", "np", "pu", "am", "cm", "bk", "cf", "es", "fm", "md", "no", "lr",
    "rf", "db", "sg", "bh", "hs", "mt", "ds", "rg", "cn", "nh", "fl", "mc", "lv", "ts", "og"
]

# The same table with conventional IUPAC capitalisation ("fe" -> "Fe"),
# derived from the lower-case list so the two can never drift apart.
PeriodicTable = [symbol.capitalize() for symbol in periodic_table]
|
PhysDock/data/constants/residue_constants.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
# One-letter -> three-letter amino-acid code pairs ("X" maps to the
# unknown residue UNK).  Both lookup directions are derived from the same
# pair list so they stay consistent.
_AA_CODE_PAIRS = [
    ("A", "ALA"), ("R", "ARG"), ("N", "ASN"), ("D", "ASP"), ("C", "CYS"),
    ("Q", "GLN"), ("E", "GLU"), ("G", "GLY"), ("H", "HIS"), ("I", "ILE"),
    ("L", "LEU"), ("K", "LYS"), ("M", "MET"), ("F", "PHE"), ("P", "PRO"),
    ("S", "SER"), ("T", "THR"), ("W", "TRP"), ("Y", "TYR"), ("V", "VAL"),
    ("X", "UNK"),
]

# 1-letter -> 3-letter code.
amino_acid_1to3 = dict(_AA_CODE_PAIRS)

# 3-letter -> 1-letter code (exact inverse of the table above).
amino_acid_3to1 = {three: one for one, three in _AA_CODE_PAIRS}
|
| 28 |
+
|
| 29 |
+
# Ligand atoms are represented with the "UNK" token.
# A "standard" CCD code is one of the polymer residue types listed below.
# NOTE(review): the nucleic-acid CCD codes carry trailing-space padding;
# the exact width must match the keys used by the reference-feature table
# in this module — confirm against the repository's feature loaders.
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
standard_rna = ["A ", "G ", "C ", "U ", "N ", ]
standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
standard_nucleics = standard_rna + standard_dna
standard_ccds_without_gap = standard_protein + standard_nucleics
GAP = ["GAP"]  # used in msa one-hot
standard_ccds = standard_protein + standard_nucleics + GAP

# Token id for every standard CCD code; the list order above defines the
# vocabulary (ALA == 0, ..., GAP == len - 1).
standard_ccd_to_order = {ccd: idx for idx, ccd in enumerate(standard_ccds)}

standard_purines = ["A ", "G ", "DA ", "DG "]
standard_pyrimidines = ["C ", "U ", "DC ", "DT "]


# Membership predicates over CCD codes.  Written as named functions rather
# than assigned lambdas (PEP 8 E731) so they carry real names in tracebacks.
def is_standard(x):
    """True if *x* is one of the standard CCD codes (GAP included)."""
    return x in standard_ccds


def is_unk(x):
    """True if *x* is an unknown/placeholder code.

    Covers the unknown protein residue (UNK), unknown RNA/DNA bases
    (N / DN), the MSA gap token (GAP) and the unknown ligand (UNL).
    """
    return x in ["UNK", "N ", "DN ", "GAP", "UNL"]


def is_protein(x):
    """True for a concrete (non-unknown) standard amino-acid code."""
    return x in standard_protein and not is_unk(x)


def is_rna(x):
    """True for a concrete (non-unknown) standard RNA base code."""
    return x in standard_rna and not is_unk(x)


def is_dna(x):
    """True for a concrete (non-unknown) standard DNA base code."""
    return x in standard_dna and not is_unk(x)


def is_nucleics(x):
    """True for a concrete (non-unknown) standard nucleic-acid code."""
    return x in standard_nucleics and not is_unk(x)


def is_purines(x):
    """True for a purine base code (A/G, RNA or DNA)."""
    return x in standard_purines


def is_pyrimidines(x):
    """True for a pyrimidine base code (C/U/DC/DT)."""
    return x in standard_pyrimidines


# Heavy-atom count per residue, excluding the terminal atom whose reference
# mask is 0 in the reference-feature table (OXT for amino acids, OP3 for
# nucleotides); None for the unknown/gap placeholder types.
standard_ccd_to_atoms_num = {
    "ALA": 5, "ARG": 11, "ASN": 8, "ASP": 8, "CYS": 6,
    "GLN": 9, "GLU": 9, "GLY": 4, "HIS": 10, "ILE": 8,
    "LEU": 8, "LYS": 9, "MET": 8, "PHE": 11, "PRO": 7,
    "SER": 6, "THR": 7, "TRP": 14, "TYR": 12, "VAL": 7,
    "UNK": None,
    "A ": 22, "G ": 23, "C ": 20, "U ": 20, "N ": None,
    "DA ": 21, "DG ": 22, "DC ": 19, "DT ": 20, "DN ": None,
    "GAP": None,
}

# Representative-atom tables: the atom used as the token centre, the three
# atoms spanning the residue frame, and the pseudo-beta atom.
standard_ccd_to_token_centre_atom_name = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_0 = {
    **{residue: "N" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_1 = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C3'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_2 = {
    **{residue: "C" for residue in standard_protein},
    **{residue: "C4'" for residue in standard_nucleics},
}

standard_ccd_to_token_pseudo_beta_atom_name = {
    **{residue: "CB" for residue in standard_protein},
    **{residue: "C4" for residue in standard_purines},
    **{residue: "C2" for residue in standard_pyrimidines},
}
# Glycine has no CB; fall back to CA.
standard_ccd_to_token_pseudo_beta_atom_name.update({"GLY": "CA"})
|
| 88 |
+
|
| 89 |
+
########################################################
# periodic table used to encode elements               #
########################################################
# 118 lowercase element symbols in atomic-number order (index 0 == "h").
periodic_table = [
    "h", "he",
    "li", "be", "b", "c", "n", "o", "f", "ne",
    "na", "mg", "al", "si", "p", "s", "cl", "ar",
    "k", "ca", "sc", "ti", "v", "cr", "mn", "fe", "co", "ni", "cu", "zn", "ga", "ge", "as", "se", "br", "kr",
    "rb", "sr", "y", "zr", "nb", "mo", "tc", "ru", "rh", "pd", "ag", "cd", "in", "sn", "sb", "te", "i", "xe",
    "cs", "ba",
    "la", "ce", "pr", "nd", "pm", "sm", "eu", "gd", "tb", "dy", "ho", "er", "tm", "yb", "lu",
    "hf", "ta", "w", "re", "os", "ir", "pt", "au", "hg", "tl", "pb", "bi", "po", "at", "rn",
    "fr", "ra",
    "ac", "th", "pa", "u", "np", "pu", "am", "cm", "bk", "cf", "es", "fm", "md", "no", "lr",
    "rf", "db", "sg", "bh", "hs", "mt", "ds", "rg", "cn", "nh", "fl", "mc", "lv", "ts", "og"
]

# Lowercase element symbol -> 0-based atomic index ("h" -> 0).
get_element_id = {}
for _element_index, _element_symbol in enumerate(periodic_table):
    get_element_id[_element_symbol] = _element_index
del _element_index, _element_symbol
|
| 107 |
+
|
| 108 |
+
##########################################################
|
| 109 |
+
|
| 110 |
+
# Idealized reference conformers for every concrete standard residue.
# Each value is a list of per-atom rows; the columns are, in order:
#   [x, y, z, ref_charge, ref_mask, element_symbol, atom_name]
# ref_mask is 0 for the terminal leaving atom (OXT for amino acids, OP3
# for nucleotides) and 1 otherwise.  NOTE(review): coordinates are assumed
# to be idealized positions in Angstroms — confirm against their source.
standard_ccd_to_reference_features_table = {
    # letters_3: [ref_pos, ref_charge, ref_mask, ref_elements, ref_atom_name_chars]
    "ALA": [
        [-0.966, 0.493, 1.500, 0., 1, "N", "N"],
        [0.257, 0.418, 0.692, 0., 1, "C", "CA"],
        [-0.094, 0.017, -0.716, 0., 1, "C", "C"],
        [-1.056, -0.682, -0.923, 0., 1, "O", "O"],
        [1.204, -0.620, 1.296, 0., 1, "C", "CB"],
        [0.661, 0.439, -1.742, 0., 0, "O", "OXT"],
    ],
    "ARG": [
        [-0.469, 1.110, -0.993, 0., 1, "N", "N"],
        [0.004, 2.294, -1.708, 0., 1, "C", "CA"],
        [-0.907, 2.521, -2.901, 0., 1, "C", "C"],
        [-1.827, 1.789, -3.242, 0., 1, "O", "O"],
        [1.475, 2.150, -2.127, 0., 1, "C", "CB"],
        [1.745, 1.017, -3.130, 0., 1, "C", "CG"],
        [3.210, 0.954, -3.557, 0., 1, "C", "CD"],
        [4.071, 0.726, -2.421, 0., 1, "N", "NE"],
        [5.469, 0.624, -2.528, 0., 1, "C", "CZ"],
        [6.259, 0.404, -1.405, 0., 1, "N", "NH1"],
        [6.078, 0.744, -3.773, 0., 1, "N", "NH2"],
        [-0.588, 3.659, -3.574, 0., 0, "O", "OXT"],
    ],
    "ASN": [
        [-0.293, 1.686, 0.094, 0., 1, "N", "N"],
        [-0.448, 0.292, -0.340, 0., 1, "C", "CA"],
        [-1.846, -0.179, -0.031, 0., 1, "C", "C"],
        [-2.510, 0.402, 0.794, 0., 1, "O", "O"],
        [0.562, -0.588, 0.401, 0., 1, "C", "CB"],
        [1.960, -0.197, -0.002, 0., 1, "C", "CG"],
        [2.132, 0.697, -0.804, 0., 1, "O", "OD1"],
        [3.019, -0.841, 0.527, 0., 1, "N", "ND2"],
        [-2.353, -1.243, -0.673, 0., 0, "O", "OXT"],
    ],
    "ASP": [
        [-0.317, 1.688, 0.066, 0., 1, "N", "N"],
        [-0.470, 0.286, -0.344, 0., 1, "C", "CA"],
        [-1.868, -0.180, -0.029, 0., 1, "C", "C"],
        [-2.534, 0.415, 0.786, 0., 1, "O", "O"],
        [0.539, -0.580, 0.413, 0., 1, "C", "CB"],
        [1.938, -0.195, 0.004, 0., 1, "C", "CG"],
        [2.109, 0.681, -0.810, 0., 1, "O", "OD1"],
        [2.992, -0.826, 0.543, 0., 1, "O", "OD2"],
        [-2.374, -1.256, -0.652, 0., 0, "O", "OXT"],
    ],
    "CYS": [
        [1.585, 0.483, -0.081, 0., 1, "N", "N"],
        [0.141, 0.450, 0.186, 0., 1, "C", "CA"],
        [-0.095, 0.006, 1.606, 0., 1, "C", "C"],
        [0.685, -0.742, 2.143, 0., 1, "O", "O"],
        [-0.533, -0.530, -0.774, 0., 1, "C", "CB"],
        [-0.247, 0.004, -2.484, 0., 1, "S", "SG"],
        [-1.174, 0.443, 2.275, 0., 0, "O", "OXT"],
    ],
    "GLN": [
        [1.858, -0.148, 1.125, 0., 1, "N", "N"],
        [0.517, 0.451, 1.112, 0., 1, "C", "CA"],
        [-0.236, 0.022, 2.344, 0., 1, "C", "C"],
        [-0.005, -1.049, 2.851, 0., 1, "O", "O"],
        [-0.236, -0.013, -0.135, 0., 1, "C", "CB"],
        [0.529, 0.421, -1.385, 0., 1, "C", "CG"],
        [-0.213, -0.036, -2.614, 0., 1, "C", "CD"],
        [-1.252, -0.650, -2.500, 0., 1, "O", "OE1"],
        [0.277, 0.236, -3.839, 0., 1, "N", "NE2"],
        [-1.165, 0.831, 2.878, 0., 0, "O", "OXT"],
    ],
    "GLU": [
        [1.199, 1.867, -0.117, 0., 1, "N", "N"],
        [1.138, 0.515, 0.453, 0., 1, "C", "CA"],
        [2.364, -0.260, 0.041, 0., 1, "C", "C"],
        [3.010, 0.096, -0.916, 0., 1, "O", "O"],
        [-0.113, -0.200, -0.062, 0., 1, "C", "CB"],
        [-1.360, 0.517, 0.461, 0., 1, "C", "CG"],
        [-2.593, -0.187, -0.046, 0., 1, "C", "CD"],
        [-2.485, -1.161, -0.753, 0., 1, "O", "OE1"],
        [-3.811, 0.269, 0.287, 0., 1, "O", "OE2"],
        [2.737, -1.345, 0.737, 0., 0, "O", "OXT"],
    ],
    "GLY": [
        [1.931, 0.090, -0.034, 0., 1, "N", "N"],
        [0.761, -0.799, -0.008, 0., 1, "C", "CA"],
        [-0.498, 0.029, -0.005, 0., 1, "C", "C"],
        [-0.429, 1.235, -0.023, 0., 1, "O", "O"],
        [-1.697, -0.574, 0.018, 0., 0, "O", "OXT"],
    ],
    "HIS": [
        [-0.040, -1.210, 0.053, 0., 1, "N", "N"],
        [1.172, -1.709, 0.652, 0., 1, "C", "CA"],
        [1.083, -3.207, 0.905, 0., 1, "C", "C"],
        [0.040, -3.770, 1.222, 0., 1, "O", "O"],
        [1.484, -0.975, 1.962, 0., 1, "C", "CB"],
        [2.940, -1.060, 2.353, 0., 1, "C", "CG"],
        [3.380, -2.075, 3.129, 0., 1, "N", "ND1"],
        [3.960, -0.251, 2.046, 0., 1, "C", "CD2"],
        [4.693, -1.908, 3.317, 0., 1, "C", "CE1"],
        [5.058, -0.801, 2.662, 0., 1, "N", "NE2"],
        [2.247, -3.882, 0.744, 0., 0, "O", "OXT"],
    ],
    "ILE": [
        [-1.944, 0.335, -0.343, 0., 1, "N", "N"],
        [-0.487, 0.519, -0.369, 0., 1, "C", "CA"],
        [0.066, -0.032, -1.657, 0., 1, "C", "C"],
        [-0.484, -0.958, -2.203, 0., 1, "O", "O"],
        [0.140, -0.219, 0.814, 0., 1, "C", "CB"],
        [-0.421, 0.341, 2.122, 0., 1, "C", "CG1"],
        [1.658, -0.027, 0.788, 0., 1, "C", "CG2"],
        [0.206, -0.397, 3.305, 0., 1, "C", "CD1"],
        [1.171, 0.504, -2.197, 0., 0, "O", "OXT"],
    ],
    "LEU": [
        [-1.661, 0.627, -0.406, 0., 1, "N", "N"],
        [-0.205, 0.441, -0.467, 0., 1, "C", "CA"],
        [0.180, -0.055, -1.836, 0., 1, "C", "C"],
        [-0.591, -0.731, -2.474, 0., 1, "O", "O"],
        [0.221, -0.583, 0.585, 0., 1, "C", "CB"],
        [-0.170, -0.079, 1.976, 0., 1, "C", "CG"],
        [0.256, -1.104, 3.029, 0., 1, "C", "CD1"],
        [0.526, 1.254, 2.250, 0., 1, "C", "CD2"],
        [1.382, 0.254, -2.348, 0., 0, "O", "OXT"],
    ],
    "LYS": [
        [1.422, 1.796, 0.198, 0., 1, "N", "N"],
        [1.394, 0.355, 0.484, 0., 1, "C", "CA"],
        [2.657, -0.284, -0.032, 0., 1, "C", "C"],
        [3.316, 0.275, -0.876, 0., 1, "O", "O"],
        [0.184, -0.278, -0.206, 0., 1, "C", "CB"],
        [-1.102, 0.282, 0.407, 0., 1, "C", "CG"],
        [-2.313, -0.351, -0.283, 0., 1, "C", "CD"],
        [-3.598, 0.208, 0.329, 0., 1, "C", "CE"],
        [-4.761, -0.400, -0.332, 0., 1, "N", "NZ"],
        [3.050, -1.476, 0.446, 0., 0, "O", "OXT"],
    ],
    "MET": [
        [-1.816, 0.142, -1.166, 0., 1, "N", "N"],
        [-0.392, 0.499, -1.214, 0., 1, "C", "CA"],
        [0.206, 0.002, -2.504, 0., 1, "C", "C"],
        [-0.236, -0.989, -3.033, 0., 1, "O", "O"],
        [0.334, -0.145, -0.032, 0., 1, "C", "CB"],
        [-0.273, 0.359, 1.277, 0., 1, "C", "CG"],
        [0.589, -0.405, 2.678, 0., 1, "S", "SD"],
        [-0.314, 0.353, 4.056, 0., 1, "C", "CE"],
        [1.232, 0.661, -3.066, 0., 0, "O", "OXT"],
    ],
    "PHE": [
        [1.317, 0.962, 1.014, 0., 1, "N", "N"],
        [-0.020, 0.426, 1.300, 0., 1, "C", "CA"],
        [-0.109, 0.047, 2.756, 0., 1, "C", "C"],
        [0.879, -0.317, 3.346, 0., 1, "O", "O"],
        [-0.270, -0.809, 0.434, 0., 1, "C", "CB"],
        [-0.181, -0.430, -1.020, 0., 1, "C", "CG"],
        [1.031, -0.498, -1.680, 0., 1, "C", "CD1"],
        [-1.314, -0.018, -1.698, 0., 1, "C", "CD2"],
        [1.112, -0.150, -3.015, 0., 1, "C", "CE1"],
        [-1.231, 0.333, -3.032, 0., 1, "C", "CE2"],
        [-0.018, 0.265, -3.691, 0., 1, "C", "CZ"],
        [-1.286, 0.113, 3.396, 0., 0, "O", "OXT"],
    ],
    "PRO": [
        [-0.816, 1.108, 0.254, 0., 1, "N", "N"],
        [0.001, -0.107, 0.509, 0., 1, "C", "CA"],
        [1.408, 0.091, 0.005, 0., 1, "C", "C"],
        [1.650, 0.980, -0.777, 0., 1, "O", "O"],
        [-0.703, -1.227, -0.286, 0., 1, "C", "CB"],
        [-2.163, -0.753, -0.439, 0., 1, "C", "CG"],
        [-2.218, 0.614, 0.276, 0., 1, "C", "CD"],
        [2.391, -0.721, 0.424, 0., 0, "O", "OXT"],
    ],
    "SER": [
        [1.525, 0.493, -0.608, 0., 1, "N", "N"],
        [0.100, 0.469, -0.252, 0., 1, "C", "CA"],
        [-0.053, 0.004, 1.173, 0., 1, "C", "C"],
        [0.751, -0.760, 1.649, 0., 1, "O", "O"],
        [-0.642, -0.489, -1.184, 0., 1, "C", "CB"],
        [-0.496, -0.049, -2.535, 0., 1, "O", "OG"],
        [-1.084, 0.440, 1.913, 0., 0, "O", "OXT"],
    ],
    "THR": [
        [1.543, -0.702, 0.430, 0., 1, "N", "N"],
        [0.122, -0.706, 0.056, 0., 1, "C", "CA"],
        [-0.038, -0.090, -1.309, 0., 1, "C", "C"],
        [0.732, 0.761, -1.683, 0., 1, "O", "O"],
        [-0.675, 0.104, 1.079, 0., 1, "C", "CB"],
        [-0.193, 1.448, 1.103, 0., 1, "O", "OG1"],
        [-0.511, -0.521, 2.466, 0., 1, "C", "CG2"],
        [-1.039, -0.488, -2.110, 0., 0, "O", "OXT"],
    ],
    "TRP": [
        [1.278, 1.121, 2.059, 0., 1, "N", "N"],
        [-0.008, 0.417, 1.970, 0., 1, "C", "CA"],
        [-0.490, 0.076, 3.357, 0., 1, "C", "C"],
        [0.308, -0.130, 4.240, 0., 1, "O", "O"],
        [0.168, -0.868, 1.161, 0., 1, "C", "CB"],
        [0.650, -0.526, -0.225, 0., 1, "C", "CG"],
        [1.928, -0.418, -0.622, 0., 1, "C", "CD1"],
        [-0.186, -0.256, -1.396, 0., 1, "C", "CD2"],
        [1.978, -0.095, -1.951, 0., 1, "N", "NE1"],
        [0.701, 0.014, -2.454, 0., 1, "C", "CE2"],
        [-1.564, -0.210, -1.615, 0., 1, "C", "CE3"],
        [0.190, 0.314, -3.712, 0., 1, "C", "CZ2"],
        [-2.044, 0.086, -2.859, 0., 1, "C", "CZ3"],
        [-1.173, 0.348, -3.907, 0., 1, "C", "CH2"],
        [-1.806, 0.001, 3.610, 0., 0, "O", "OXT"],
    ],
    "TYR": [
        [1.320, 0.952, 1.428, 0., 1, "N", "N"],
        [-0.018, 0.429, 1.734, 0., 1, "C", "CA"],
        [-0.103, 0.094, 3.201, 0., 1, "C", "C"],
        [0.886, -0.254, 3.799, 0., 1, "O", "O"],
        [-0.274, -0.831, 0.907, 0., 1, "C", "CB"],
        [-0.189, -0.496, -0.559, 0., 1, "C", "CG"],
        [1.022, -0.589, -1.219, 0., 1, "C", "CD1"],
        [-1.324, -0.102, -1.244, 0., 1, "C", "CD2"],
        [1.103, -0.282, -2.563, 0., 1, "C", "CE1"],
        [-1.247, 0.210, -2.587, 0., 1, "C", "CE2"],
        [-0.032, 0.118, -3.252, 0., 1, "C", "CZ"],
        [0.044, 0.420, -4.574, 0., 1, "O", "OH"],
        [-1.279, 0.184, 3.842, 0., 0, "O", "OXT"],
    ],
    "VAL": [
        [1.564, -0.642, 0.454, 0., 1, "N", "N"],
        [0.145, -0.698, 0.079, 0., 1, "C", "CA"],
        [-0.037, -0.093, -1.288, 0., 1, "C", "C"],
        [0.703, 0.784, -1.664, 0., 1, "O", "O"],
        [-0.682, 0.086, 1.098, 0., 1, "C", "CB"],
        [-0.497, -0.528, 2.487, 0., 1, "C", "CG1"],
        [-0.218, 1.543, 1.119, 0., 1, "C", "CG2"],
        [-1.022, -0.529, -2.089, 0., 0, "O", "OXT"],
    ],
    # RNA nucleotides (OP3 is the masked leaving atom).
    "A ": [
        [2.135, -1.141, -5.313, 0., 0, "O", "OP3"],
        [1.024, -0.137, -4.723, 0., 1, "P", "P"],
        [1.633, 1.190, -4.488, 0., 1, "O", "OP1"],
        [-0.183, 0.005, -5.778, 0., 1, "O", "OP2"],
        [0.456, -0.720, -3.334, 0., 1, "O", "O5'"],
        [-0.520, 0.209, -2.863, 0., 1, "C", "C5'"],
        [-1.101, -0.287, -1.538, 0., 1, "C", "C4'"],
        [-0.064, -0.383, -0.538, 0., 1, "O", "O4'"],
        [-2.105, 0.739, -0.969, 0., 1, "C", "C3'"],
        [-3.445, 0.360, -1.287, 0., 1, "O", "O3'"],
        [-1.874, 0.684, 0.558, 0., 1, "C", "C2'"],
        [-3.065, 0.271, 1.231, 0., 1, "O", "O2'"],
        [-0.755, -0.367, 0.729, 0., 1, "C", "C1'"],
        [0.158, 0.029, 1.803, 0., 1, "N", "N9"],
        [1.265, 0.813, 1.672, 0., 1, "C", "C8"],
        [1.843, 0.963, 2.828, 0., 1, "N", "N7"],
        [1.143, 0.292, 3.773, 0., 1, "C", "C5"],
        [1.290, 0.091, 5.156, 0., 1, "C", "C6"],
        [2.344, 0.664, 5.846, 0., 1, "N", "N6"],
        [0.391, -0.656, 5.787, 0., 1, "N", "N1"],
        [-0.617, -1.206, 5.136, 0., 1, "C", "C2"],
        [-0.792, -1.051, 3.841, 0., 1, "N", "N3"],
        [0.056, -0.320, 3.126, 0., 1, "C", "C4"],
    ],
    "G ": [
        [-1.945, -1.360, 5.599, 0., 0, "O", "OP3"],
        [-0.911, -0.277, 5.008, 0., 1, "P", "P"],
        [-1.598, 1.022, 4.844, 0., 1, "O", "OP1"],
        [0.325, -0.105, 6.025, 0., 1, "O", "OP2"],
        [-0.365, -0.780, 3.580, 0., 1, "O", "O5'"],
        [0.542, 0.217, 3.109, 0., 1, "C", "C5'"],
        [1.100, -0.200, 1.748, 0., 1, "C", "C4'"],
        [0.033, -0.318, 0.782, 0., 1, "O", "O4'"],
        [2.025, 0.898, 1.182, 0., 1, "C", "C3'"],
        [3.395, 0.582, 1.439, 0., 1, "O", "O3'"],
        [1.741, 0.884, -0.338, 0., 1, "C", "C2'"],
        [2.927, 0.560, -1.066, 0., 1, "O", "O2'"],
        [0.675, -0.220, -0.507, 0., 1, "C", "C1'"],
        [-0.297, 0.162, -1.534, 0., 1, "N", "N9"],
        [-1.440, 0.880, -1.334, 0., 1, "C", "C8"],
        [-2.066, 1.037, -2.464, 0., 1, "N", "N7"],
        [-1.364, 0.431, -3.453, 0., 1, "C", "C5"],
        [-1.556, 0.279, -4.846, 0., 1, "C", "C6"],
        [-2.534, 0.755, -5.397, 0., 1, "O", "O6"],
        [-0.626, -0.401, -5.551, 0., 1, "N", "N1"],
        [0.459, -0.934, -4.923, 0., 1, "C", "C2"],
        [1.384, -1.626, -5.664, 0., 1, "N", "N2"],
        [0.649, -0.800, -3.630, 0., 1, "N", "N3"],
        [-0.226, -0.134, -2.868, 0., 1, "C", "C4"],
    ],
    "C ": [
        [2.147, -1.021, -4.678, 0., 0, "O", "OP3"],
        [1.049, -0.039, -4.028, 0., 1, "P", "P"],
        [1.692, 1.237, -3.646, 0., 1, "O", "OP1"],
        [-0.116, 0.246, -5.102, 0., 1, "O", "OP2"],
        [0.415, -0.733, -2.721, 0., 1, "O", "O5'"],
        [-0.546, 0.181, -2.193, 0., 1, "C", "C5'"],
        [-1.189, -0.419, -0.942, 0., 1, "C", "C4'"],
        [-0.190, -0.648, 0.076, 0., 1, "O", "O4'"],
        [-2.178, 0.583, -0.307, 0., 1, "C", "C3'"],
        [-3.518, 0.283, -0.703, 0., 1, "O", "O3'"],
        [-2.001, 0.373, 1.215, 0., 1, "C", "C2'"],
        [-3.228, -0.059, 1.806, 0., 1, "O", "O2'"],
        [-0.924, -0.729, 1.317, 0., 1, "C", "C1'"],
        [-0.036, -0.470, 2.453, 0., 1, "N", "N1"],
        [0.652, 0.683, 2.514, 0., 1, "C", "C2"],
        [0.529, 1.504, 1.620, 0., 1, "O", "O2"],
        [1.467, 0.945, 3.535, 0., 1, "N", "N3"],
        [1.620, 0.070, 4.520, 0., 1, "C", "C4"],
        [2.464, 0.350, 5.569, 0., 1, "N", "N4"],
        [0.916, -1.151, 4.483, 0., 1, "C", "C5"],
        [0.087, -1.399, 3.442, 0., 1, "C", "C6"],
    ],
    "U ": [
        [-2.122, 1.033, -4.690, 0., 0, "O", "OP3"],
        [-1.030, 0.047, -4.037, 0., 1, "P", "P"],
        [-1.679, -1.228, -3.660, 0., 1, "O", "OP1"],
        [0.138, -0.241, -5.107, 0., 1, "O", "OP2"],
        [-0.399, 0.736, -2.726, 0., 1, "O", "O5'"],
        [0.557, -0.182, -2.196, 0., 1, "C", "C5'"],
        [1.197, 0.415, -0.942, 0., 1, "C", "C4'"],
        [0.194, 0.645, 0.074, 0., 1, "O", "O4'"],
        [2.181, -0.588, -0.301, 0., 1, "C", "C3'"],
        [3.524, -0.288, -0.686, 0., 1, "O", "O3'"],
        [1.995, -0.383, 1.218, 0., 1, "C", "C2'"],
        [3.219, 0.046, 1.819, 0., 1, "O", "O2'"],
        [0.922, 0.723, 1.319, 0., 1, "C", "C1'"],
        [0.028, 0.464, 2.451, 0., 1, "N", "N1"],
        [-0.690, -0.671, 2.486, 0., 1, "C", "C2"],
        [-0.587, -1.474, 1.580, 0., 1, "O", "O2"],
        [-1.515, -0.936, 3.517, 0., 1, "N", "N3"],
        [-1.641, -0.055, 4.530, 0., 1, "C", "C4"],
        [-2.391, -0.292, 5.460, 0., 1, "O", "O4"],
        [-0.894, 1.146, 4.502, 0., 1, "C", "C5"],
        [-0.070, 1.384, 3.459, 0., 1, "C", "C6"],
    ],
    # DNA nucleotides (no O2'; DT additionally carries the C7 methyl).
    "DA ": [
        [1.845, -1.282, -5.339, 0., 0, "O", "OP3"],
        [0.934, -0.156, -4.636, 0., 1, "P", "P"],
        [1.781, 0.996, -4.255, 0., 1, "O", "OP1"],
        [-0.204, 0.331, -5.665, 0., 1, "O", "OP2"],
        [0.241, -0.771, -3.320, 0., 1, "O", "O5'"],
        [-0.549, 0.270, -2.744, 0., 1, "C", "C5'"],
        [-1.239, -0.251, -1.482, 0., 1, "C", "C4'"],
        [-0.267, -0.564, -0.458, 0., 1, "O", "O4'"],
        [-2.105, 0.859, -0.835, 0., 1, "C", "C3'"],
        [-3.409, 0.895, -1.418, 0., 1, "O", "O3'"],
        [-2.173, 0.398, 0.640, 0., 1, "C", "C2'"],
        [-0.965, -0.545, 0.797, 0., 1, "C", "C1'"],
        [-0.078, -0.047, 1.852, 0., 1, "N", "N9"],
        [0.962, 0.817, 1.689, 0., 1, "C", "C8"],
        [1.535, 1.044, 2.835, 0., 1, "N", "N7"],
        [0.897, 0.346, 3.805, 0., 1, "C", "C5"],
        [1.069, 0.196, 5.191, 0., 1, "C", "C6"],
        [2.079, 0.869, 5.856, 0., 1, "N", "N6"],
        [0.236, -0.603, 5.850, 0., 1, "N", "N1"],
        [-0.729, -1.249, 5.224, 0., 1, "C", "C2"],
        [-0.925, -1.144, 3.927, 0., 1, "N", "N3"],
        [-0.142, -0.368, 3.184, 0., 1, "C", "C4"],
    ],
    "DG ": [
        [-1.603, -1.547, 5.624, 0., 0, "O", "OP3"],
        [-0.818, -0.321, 4.935, 0., 1, "P", "P"],
        [-1.774, 0.766, 4.630, 0., 1, "O", "OP1"],
        [0.312, 0.224, 5.941, 0., 1, "O", "OP2"],
        [-0.126, -0.826, 3.572, 0., 1, "O", "O5'"],
        [0.550, 0.300, 3.011, 0., 1, "C", "C5'"],
        [1.233, -0.113, 1.706, 0., 1, "C", "C4'"],
        [0.253, -0.471, 0.705, 0., 1, "O", "O4'"],
        [1.976, 1.091, 1.073, 0., 1, "C", "C3'"],
        [3.294, 1.218, 1.612, 0., 1, "O", "O3'"],
        [2.026, 0.692, -0.421, 0., 1, "C", "C2'"],
        [0.897, -0.345, -0.573, 0., 1, "C", "C1'"],
        [-0.068, 0.111, -1.575, 0., 1, "N", "N9"],
        [-1.172, 0.877, -1.341, 0., 1, "C", "C8"],
        [-1.804, 1.094, -2.458, 0., 1, "N", "N7"],
        [-1.145, 0.482, -3.472, 0., 1, "C", "C5"],
        [-1.361, 0.377, -4.866, 0., 1, "C", "C6"],
        [-2.321, 0.914, -5.391, 0., 1, "O", "O6"],
        [-0.473, -0.327, -5.601, 0., 1, "N", "N1"],
        [0.593, -0.928, -5.003, 0., 1, "C", "C2"],
        [1.474, -1.643, -5.774, 0., 1, "N", "N2"],
        [0.804, -0.839, -3.709, 0., 1, "N", "N3"],
        [-0.027, -0.152, -2.917, 0., 1, "C", "C4"],
    ],
    "DC ": [
        [1.941, -1.055, -4.672, 0., 0, "O", "OP3"],
        [0.987, -0.017, -3.894, 0., 1, "P", "P"],
        [1.802, 1.099, -3.365, 0., 1, "O", "OP1"],
        [-0.119, 0.560, -4.910, 0., 1, "O", "OP2"],
        [0.255, -0.772, -2.674, 0., 1, "O", "O5'"],
        [-0.571, 0.196, -2.027, 0., 1, "C", "C5'"],
        [-1.300, -0.459, -0.852, 0., 1, "C", "C4'"],
        [-0.363, -0.863, 0.171, 0., 1, "O", "O4'"],
        [-2.206, 0.569, -0.129, 0., 1, "C", "C3'"],
        [-3.488, 0.649, -0.756, 0., 1, "O", "O3'"],
        [-2.322, -0.040, 1.288, 0., 1, "C", "C2'"],
        [-1.106, -0.981, 1.395, 0., 1, "C", "C1'"],
        [-0.267, -0.584, 2.528, 0., 1, "N", "N1"],
        [0.270, 0.648, 2.563, 0., 1, "C", "C2"],
        [0.052, 1.424, 1.647, 0., 1, "O", "O2"],
        [1.037, 1.035, 3.581, 0., 1, "N", "N3"],
        [1.291, 0.212, 4.589, 0., 1, "C", "C4"],
        [2.085, 0.622, 5.635, 0., 1, "N", "N4"],
        [0.746, -1.088, 4.580, 0., 1, "C", "C5"],
        [-0.035, -1.465, 3.541, 0., 1, "C", "C6"],
    ],
    "DT ": [
        [-3.912, -2.311, 1.636, 0., 0, "O", "OP3"],
        [-3.968, -1.665, 3.118, 0., 1, "P", "P"],
        [-4.406, -2.599, 4.208, 0., 1, "O", "OP1"],
        [-4.901, -0.360, 2.920, 0., 1, "O", "OP2"],
        [-2.493, -1.028, 3.315, 0., 1, "O", "O5'"],
        [-2.005, -0.136, 2.327, 0., 1, "C", "C5'"],
        [-0.611, 0.328, 2.728, 0., 1, "C", "C4'"],
        [0.247, -0.829, 2.764, 0., 1, "O", "O4'"],
        [0.008, 1.286, 1.720, 0., 1, "C", "C3'"],
        [0.965, 2.121, 2.368, 0., 1, "O", "O3'"],
        [0.710, 0.360, 0.754, 0., 1, "C", "C2'"],
        [1.157, -0.778, 1.657, 0., 1, "C", "C1'"],
        [1.164, -2.047, 0.989, 0., 1, "N", "N1"],
        [2.333, -2.544, 0.374, 0., 1, "C", "C2"],
        [3.410, -1.945, 0.363, 0., 1, "O", "O2"],
        [2.194, -3.793, -0.240, 0., 1, "N", "N3"],
        [1.047, -4.570, -0.300, 0., 1, "C", "C4"],
        [0.995, -5.663, -0.857, 0., 1, "O", "O4"],
        [-0.143, -3.980, 0.369, 0., 1, "C", "C5"],
        [-1.420, -4.757, 0.347, 0., 1, "C", "C7"],
        [-0.013, -2.784, 0.958, 0., 1, "C", "C6"],
    ],
}
|
| 531 |
+
|
| 532 |
+
# Ordered atom-name list for each concrete (non-unknown) standard residue,
# taken from the last column of the reference-features table.
standard_ccd_to_ref_atom_name_chars = {}
for _ccd in standard_ccds:
    if is_unk(_ccd):
        continue
    standard_ccd_to_ref_atom_name_chars[_ccd] = [
        row[-1] for row in standard_ccd_to_reference_features_table[_ccd]
    ]
del _ccd
|
| 536 |
+
|
| 537 |
+
# Cached identity matrices reused as one-hot lookup tables.
eye_3, eye_5, eye_7, eye_9, eye_32, eye_64, eye_128 = (
    np.eye(size) for size in (3, 5, 7, 9, 32, 64, 128)
)
# NOTE(review): eye8/eye5 duplicate the naming scheme above (eye_5 vs eye5);
# kept as independent arrays so mutation of one cannot affect the other.
eye8 = np.eye(8)
eye5 = np.eye(5)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def _get_ref_feat_from_ccd_data(ccd, ref_feat_table):
    """Build the per-atom reference feature matrix for one CCD code.

    Each atom row concatenates:
      * the first five table columns (x, y, z, charge, mask) as floats,
      * a 128-way one-hot of the element (index from ``get_element_id``),
      * four 64-way one-hots encoding the atom name left-justified to four
        characters (each character indexed as ``ord(c) - 32``).

    Returns an array of shape (n_atoms, 5 + 128 + 4 * 64).
    """
    rows = []
    for atom in ref_feat_table[ccd]:
        element_onehot = eye_128[get_element_id[atom[5].lower()]]
        name_onehots = [eye_64[ord(ch) - 32] for ch in f"{atom[-1]:<4}"]
        rows.append(
            np.concatenate(
                [np.array(atom[:5]), element_onehot, *name_onehots], axis=-1
            )
        )
    return np.stack(rows, axis=0)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# Precomputed reference feature matrices for every concrete standard residue.
standard_ccd_to_ref_feat = {}
for _ccd in standard_ccds:
    if not is_unk(_ccd):
        standard_ccd_to_ref_feat[_ccd] = _get_ref_feat_from_ccd_data(
            _ccd, standard_ccd_to_reference_features_table
        )
del _ccd
|
PhysDock/data/constants/restype_constants.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from .PDBData import protein_letters_3to1_extended, nucleic_letters_3to1_extended
|
| 3 |
+
|
| 4 |
+
# One-letter -> three-character residue codes (padded to width 3).
# Protein residues use their usual letters; digits 0-9 encode RNA ("0"-"4")
# and DNA ("5"-"9") nucleotides.
restype_1_to_3 = {
    "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
    "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
    "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
    "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    "X": "UNK",
    "0": "A  ", "1": "G  ", "2": "C  ", "3": "U  ", "4": "N  ",
    "5": "DA ", "6": "DG ", "7": "DC ", "8": "DT ", "9": "DN ",
}
# Single nucleic-acid letters -> padded residue codes; unknown "X" falls
# back to "N  ".
na_c_to_type = {
    "A": "A  ", "G": "G  ", "C": "C  ", "U": "U  ", "N": "N  ", "T": "T  ", "X": "N  ",
}

# Inverse mapping.  "T  " never appears as a value of restype_1_to_3
# ("8" maps to "DT "), so it is patched in by hand and mapped to "8".
restype_3_to_1 = {three: one for one, three in restype_1_to_3.items()}
restype_3_to_1["T  "] = "8"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Canonical residue-type ordering: 21 amino acids, 5 RNA, 5 DNA codes.
restypes3 = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
    "A  ", "G  ", "C  ", "U  ", "N  ",
    "DA ", "DG ", "DC ", "DT ", "DN ",
]
restypes1 = [restype_3_to_1[code] for code in restypes3]

# Extended lookup: modified/non-standard protein codes first, then nucleic
# codes (which may overwrite them), then the canonical table so that
# standard codes always win.
restype_3_to_1_extended = {
    f"{long_code:<3}": short
    for long_code, short in protein_letters_3to1_extended.items()
}
# TODO: How to distinguish RNA and DNA
for long_code, short in nucleic_letters_3to1_extended.items():
    restype_3_to_1_extended[long_code] = restype_3_to_1[na_c_to_type[short]]
restype_3_to_1_extended.update(restype_3_to_1)
|
| 37 |
+
|
| 38 |
+
############
|
| 39 |
+
# --- Standard residue vocabularies --------------------------------------
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
standard_rna = ["A  ", "G  ", "C  ", "U  ", "N  ", ]
standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
standard_nucleics = standard_rna + standard_dna
standard_ccds_without_gap = standard_protein + standard_nucleics
GAP = ["GAP"]  # used in msa one-hot
standard_ccds = standard_protein + standard_nucleics + GAP

# CCD code -> integer index into `standard_ccds`.
# (Loop variable renamed from `id`, which shadowed the builtin.)
standard_ccd_to_order = {code: index for index, code in enumerate(standard_ccds)}

standard_purines = ["A  ", "G  ", "DA ", "DG "]
standard_pyrimidines = ["C  ", "U  ", "DC ", "DT "]


# Membership predicates (converted from lambda assignments to `def`s,
# per PEP 8; behavior unchanged).
def is_standard(x):
    """True if *x* is one of the standard CCD codes (including UNK-like)."""
    return x in standard_ccds


def is_unk(x):
    """True for 'unknown' placeholder codes (protein, RNA, DNA, gap, UNL)."""
    return x in ["UNK", "N  ", "DN ", "GAP", "UNL"]


def is_protein(x):
    """True for a known (non-UNK) standard amino-acid code."""
    return x in standard_protein and not is_unk(x)


def is_rna(x):
    """True for a known (non-N) standard RNA code."""
    return x in standard_rna and not is_unk(x)


def is_dna(x):
    """True for a known (non-DN) standard DNA code."""
    return x in standard_dna and not is_unk(x)


def is_nucleics(x):
    """True for a known standard RNA or DNA code."""
    return x in standard_nucleics and not is_unk(x)


def is_purines(x):
    """True for a purine nucleotide code (A/G, DA/DG)."""
    return x in standard_purines


def is_pyrimidines(x):
    """True for a pyrimidine nucleotide code (C/U, DC/DT)."""
    return x in standard_pyrimidines


# Atom counts per standard CCD; None for UNK-like codes where the count is
# undefined.
standard_ccd_to_atoms_num = dict(zip(standard_ccds, [
    5, 11, 8, 8, 6, 9, 9, 4, 10, 8,
    8, 9, 8, 11, 7, 6, 7, 14, 12, 7, None,
    22, 23, 20, 20, None,
    21, 22, 19, 20, None,
    None,
]))

# Token centre atom: CA for amino acids, C1' for nucleotides.
standard_ccd_to_token_centre_atom_name = {
    **{aa: "CA" for aa in standard_protein},
    **{nt: "C1'" for nt in standard_nucleics},
}

# Frame atoms: (N, CA, C) for amino acids, (C1', C3', C4') for nucleotides.
standard_ccd_to_frame_atom_name_0 = {
    **{aa: "N" for aa in standard_protein},
    **{nt: "C1'" for nt in standard_nucleics},
}

standard_ccd_to_frame_atom_name_1 = {
    **{aa: "CA" for aa in standard_protein},
    **{nt: "C3'" for nt in standard_nucleics},
}

standard_ccd_to_frame_atom_name_2 = {
    **{aa: "C" for aa in standard_protein},
    **{nt: "C4'" for nt in standard_nucleics},
}

# Pseudo-beta atom: CB for amino acids (CA for glycine, patched below),
# C4 for purines and C2 for pyrimidines.  Note the unknown nucleotides
# ("N  ", "DN ") intentionally have no entry, as in the original table.
standard_ccd_to_token_pseudo_beta_atom_name = {
    **{aa: "CB" for aa in standard_protein},
    **{nt: "C4" for nt in standard_purines},
    **{nt: "C2" for nt in standard_pyrimidines},
}
standard_ccd_to_token_pseudo_beta_atom_name.update({"GLY": "CA"})
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Precomputed identity matrices used as one-hot lookup tables
# (row i is the one-hot encoding of index i).
eye_64, eye_128, eye_9, eye_7, eye_3, eye_32, eye_5 = (
    np.eye(n) for n in (64, 128, 9, 7, 3, 32, 5)
)
# Legacy aliases without the underscore naming (kept for callers).
eye8 = np.eye(8)
eye5 = np.eye(5)
|
PhysDock/data/feature_loader.py
ADDED
|
@@ -0,0 +1,1283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from functools import reduce
|
| 5 |
+
from operator import add
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from PhysDock.data.constants.PDBData import protein_letters_3to1_extended
|
| 12 |
+
from PhysDock.data.constants import restype_constants as rc
|
| 13 |
+
from PhysDock.utils.io_utils import convert_md5_string, load_json, load_pkl, dump_txt, find_files
|
| 14 |
+
from PhysDock.data.tools.feature_processing_multimer import pair_and_merge
|
| 15 |
+
from PhysDock.utils.tensor_utils import centre_random_augmentation_np_apply, dgram_from_positions, \
|
| 16 |
+
centre_random_augmentation_np_batch
|
| 17 |
+
from PhysDock.data.constants.periodic_table import PeriodicTable
|
| 18 |
+
from PhysDock.data.tools.rdkit import get_features_from_smi
|
| 19 |
+
|
| 20 |
+
# Single-character chain identifiers in assignment order: uppercase,
# lowercase, then digits (62 in total).
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class FeatureLoader:
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
# Dataset Config
|
| 27 |
+
dataset_path=None,
|
| 28 |
+
msa_features_dir=None,
|
| 29 |
+
ccd_id_meta_data=None,
|
| 30 |
+
|
| 31 |
+
crop_size=256,
|
| 32 |
+
atom_crop_size=256 * 8,
|
| 33 |
+
|
| 34 |
+
# Infer config
|
| 35 |
+
inference_mode=False,
|
| 36 |
+
infer_pocket_type="atom", # "ca"
|
| 37 |
+
infer_pocket_cutoff=6, # 8 10 12
|
| 38 |
+
infer_pocket_dist_type="ligand", # "ligand_centre"
|
| 39 |
+
infer_use_pocket=True,
|
| 40 |
+
infer_use_key_res=True,
|
| 41 |
+
|
| 42 |
+
# Train Config
|
| 43 |
+
train_pocket_type_atom_ratio=0.5,
|
| 44 |
+
train_pocket_cutoff_ligand_min=6,
|
| 45 |
+
train_pocket_cutoff_ligand_max=12,
|
| 46 |
+
train_pocket_cutoff_ligand_centre_min=10,
|
| 47 |
+
train_pocket_cutoff_ligand_centre_max=16,
|
| 48 |
+
train_pocket_dist_type_ligand_ratio=0.5,
|
| 49 |
+
train_use_pocket_ratio=0.5,
|
| 50 |
+
train_use_key_res_ratio=0.5,
|
| 51 |
+
|
| 52 |
+
train_shuffle_sym_id=True,
|
| 53 |
+
train_spatial_crop_ligand_ratio=0.2,
|
| 54 |
+
train_spatial_crop_interface_ratio=0.4,
|
| 55 |
+
train_spatial_crop_interface_threshold=15.,
|
| 56 |
+
train_charility_augmentation_ratio=0.1,
|
| 57 |
+
train_use_template_ratio=0.75,
|
| 58 |
+
train_template_mask_max_ratio=0.4,
|
| 59 |
+
|
| 60 |
+
# Other Configs
|
| 61 |
+
max_msa_clusters=128,
|
| 62 |
+
key_res_random_mask_ratio=0.5,
|
| 63 |
+
|
| 64 |
+
# Abalation
|
| 65 |
+
use_x_gt_ligand_as_ref_pos=False,
|
| 66 |
+
|
| 67 |
+
# Recycle
|
| 68 |
+
num_recycles=None
|
| 69 |
+
):
|
| 70 |
+
# Init Dataset
|
| 71 |
+
if dataset_path is not None:
|
| 72 |
+
self.msa_features_path = os.path.join(dataset_path, "msa_features")
|
| 73 |
+
self.uniprot_msa_features_path = os.path.join(dataset_path, "uniprot_msa_features")
|
| 74 |
+
if os.path.exists(os.path.join(dataset_path, "train_val")):
|
| 75 |
+
self.used_sample_ids = find_files(os.path.join(dataset_path, "train_val"))
|
| 76 |
+
if os.path.exists(os.path.join(dataset_path, "train_val_weights.json")):
|
| 77 |
+
weights = load_json(os.path.join(dataset_path, "train_val_weights.json"))
|
| 78 |
+
self.weights = np.array([weights[sample_id] for sample_id in self.used_sample_ids])
|
| 79 |
+
self.probabilities = torch.from_numpy(self.weights / self.weights.sum())
|
| 80 |
+
if msa_features_dir is not None:
|
| 81 |
+
self.msa_features_path = os.path.join(msa_features_dir, "msa_features")
|
| 82 |
+
self.uniprot_msa_features_path = os.path.join(msa_features_dir, "uniprot_msa_features")
|
| 83 |
+
# self.ccd_id_meta_data = load_pkl(
|
| 84 |
+
# "/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_meta_data_confs_chars.pkl.gz")
|
| 85 |
+
# self.ccd_id_ref_mol = load_pkl("/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_dict.pkl.gz")
|
| 86 |
+
#
|
| 87 |
+
# # self.ccd_id_meta_data = load_pkl(
|
| 88 |
+
# # "/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_meta_data_confs_chars.pkl.gz")
|
| 89 |
+
# ddd_ccd_meta = load_pkl(
|
| 90 |
+
# "/2022133002/projects/stdock/stdock_v9.5/scripts/phi_ccd_meta_data_confs_chars.pkl.gz")
|
| 91 |
+
# self.ccd_id_meta_data.update(ddd_ccd_meta)
|
| 92 |
+
# self.ccd_id_ref_mol = load_pkl("/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_dict.pkl.gz")
|
| 93 |
+
# ccd_id_ref_mol_phi = load_pkl("/2022133002/projects/stdock/stdock_v9.5/scripts/phi_ccd_dict.pkl.gz")
|
| 94 |
+
# self.ccd_id_ref_mol.update(ccd_id_ref_mol_phi)
|
| 95 |
+
if ccd_id_meta_data is None:
|
| 96 |
+
assert os.path.exists(os.path.join(dataset_path, "ccd_id_meta_data.pkl.gz")) or ccd_id_meta_data is not None
|
| 97 |
+
self.ccd_id_meta_data = load_pkl(os.path.join(dataset_path, "ccd_id_meta_data.pkl.gz"))
|
| 98 |
+
else:
|
| 99 |
+
self.ccd_id_meta_data = load_pkl(ccd_id_meta_data)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Inference Config
|
| 103 |
+
self.inference_mode = inference_mode
|
| 104 |
+
self.infer_use_pocket = infer_use_pocket
|
| 105 |
+
self.infer_use_key_res = infer_use_key_res
|
| 106 |
+
self.infer_pocket_type = infer_pocket_type
|
| 107 |
+
self.infer_pocket_cutoff = infer_pocket_cutoff
|
| 108 |
+
self.infer_pocket_dist_type = infer_pocket_dist_type
|
| 109 |
+
|
| 110 |
+
# Training Config
|
| 111 |
+
self.train_pocket_type_atom_ratio = train_pocket_type_atom_ratio
|
| 112 |
+
self.train_pocket_cutoff_ligand_min = train_pocket_cutoff_ligand_min
|
| 113 |
+
self.train_pocket_cutoff_ligand_max = train_pocket_cutoff_ligand_max
|
| 114 |
+
self.train_pocket_cutoff_ligand_centre_min = train_pocket_cutoff_ligand_centre_min
|
| 115 |
+
self.train_pocket_cutoff_ligand_centre_max = train_pocket_cutoff_ligand_centre_max
|
| 116 |
+
self.train_pocket_dist_type_ligand_ratio = train_pocket_dist_type_ligand_ratio
|
| 117 |
+
self.train_use_pocket_ratio = train_use_pocket_ratio
|
| 118 |
+
self.train_use_key_res_ratio = train_use_key_res_ratio
|
| 119 |
+
self.train_shuffle_sym_id = train_shuffle_sym_id
|
| 120 |
+
self.train_spatial_crop_ligand_ratio = train_spatial_crop_ligand_ratio
|
| 121 |
+
self.train_spatial_crop_interface_ratio = train_spatial_crop_interface_ratio
|
| 122 |
+
self.train_spatial_crop_interface_threshold = train_spatial_crop_interface_threshold
|
| 123 |
+
self.train_charility_augmentation_ratio = train_charility_augmentation_ratio
|
| 124 |
+
self.train_use_template_ratio = train_use_template_ratio
|
| 125 |
+
self.train_template_mask_max_ratio = train_template_mask_max_ratio
|
| 126 |
+
|
| 127 |
+
# Other Configs
|
| 128 |
+
self.token_bond_threshold = 2.4
|
| 129 |
+
self.key_res_random_mask_ratio = key_res_random_mask_ratio
|
| 130 |
+
self.crop_size = crop_size
|
| 131 |
+
self.atom_crop_size = atom_crop_size
|
| 132 |
+
self.max_msa_clusters = max_msa_clusters
|
| 133 |
+
|
| 134 |
+
self.use_x_gt_ligand_as_ref_pos = use_x_gt_ligand_as_ref_pos
|
| 135 |
+
|
| 136 |
+
self.num_recycles = num_recycles
|
| 137 |
+
|
| 138 |
+
    def _update_CONF_META_DATA(self, CONF_META_DATA, ccds):
        """Ensure every CCD code in `ccds` has a cached conformer-metadata entry.

        For each code not yet in `CONF_META_DATA`, reads its reference
        features from ``self.ccd_id_meta_data`` and stores:
          * "ref_feat": per-atom features — centroid-centred reference
            positions, charge, one-hot element / degree / hybridization /
            implicit-valence / chirality encodings, aromaticity flag and
            ring-size membership flags, concatenated along the last axis;
          * "rel_tok_feat": pairwise features — one-hot topological distance
            and bond type plus bond flags, concatenated along the last axis;
          * the raw "ref_atom_name_chars", "ref_element", "token_bonds".
        Mutates `CONF_META_DATA` in place and returns it.
        """
        for ccd in ccds:
            if ccd in CONF_META_DATA:
                continue  # already cached
            ccd_features = self.ccd_id_meta_data[ccd]
            ref_pos = ccd_features["ref_pos"]
            # Centre the reference conformer at its centroid.
            ref_pos = ref_pos - np.mean(ref_pos, axis=0, keepdims=True)
            CONF_META_DATA[ccd] = {
                "ref_feat": np.concatenate([
                    ref_pos,
                    ccd_features["ref_charge"][..., None],
                    rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
                    ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
                    rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
                    rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
                    rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
                    rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
                    ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
                ], axis=-1),
                "rel_tok_feat": np.concatenate([
                    rc.eye_32[ccd_features["d_token"]].astype(np.float32),
                    rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
                    ccd_features["token_bonds"].astype(np.float32)[..., None],
                    ccd_features["bond_as_double"].astype(np.float32)[..., None],
                    ccd_features["bond_in_ring"].astype(np.float32)[..., None],
                    ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
                    ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
                ], axis=-1),
                "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
                "ref_element": ccd_features["ref_element"],
                "token_bonds": ccd_features["token_bonds"],
            }

        return CONF_META_DATA
|
| 177 |
+
|
| 178 |
+
def _update_chain_feature(self, chain_feature, CONF_META_DATA, use_pocket, use_key_res, ):
|
| 179 |
+
ccds_ori = chain_feature["ccds"]
|
| 180 |
+
chain_class = chain_feature["chain_class"]
|
| 181 |
+
if chain_class == "protein":
|
| 182 |
+
sequence = "".join([protein_letters_3to1_extended.get(ccd, "X") for ccd in ccds_ori])
|
| 183 |
+
md5 = convert_md5_string(f"protein:{sequence}")
|
| 184 |
+
# with open("add_msa.fasta", "a") as f:
|
| 185 |
+
# f.write(f">{md5}\n{sequence}\n")
|
| 186 |
+
try:
|
| 187 |
+
# import shutil
|
| 188 |
+
# shutil.copy(
|
| 189 |
+
# os.path.join(self.msa_features_path, f"{md5}.pkl.gz"),
|
| 190 |
+
# os.path.join("/home/zhangkexin/research/PhysDock/examples/demo/features/msa_features",
|
| 191 |
+
# f"{md5}.pkl.gz")
|
| 192 |
+
# )
|
| 193 |
+
# shutil.copy(
|
| 194 |
+
# os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"),
|
| 195 |
+
# os.path.join("/home/zhangkexin/research/PhysDock/examples/demo/features/uniprot_msa_features",
|
| 196 |
+
# f"{md5}.pkl.gz")
|
| 197 |
+
# )
|
| 198 |
+
chain_feature.update(
|
| 199 |
+
load_pkl(os.path.join(self.msa_features_path, f"{md5}.pkl.gz"))
|
| 200 |
+
)
|
| 201 |
+
except:
|
| 202 |
+
print(f"Can't find msa feature!!! md5: {md5}")
|
| 203 |
+
with open("add_msa.fasta", "a") as f:
|
| 204 |
+
f.write(f">{md5}\n{sequence}\n")
|
| 205 |
+
|
| 206 |
+
chain_feature.update(
|
| 207 |
+
load_pkl(os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"))
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
chain_feature["msa"] = np.array([[rc.standard_ccds.index(ccd)
|
| 211 |
+
if ccd in rc.standard_ccds else 20 for ccd in ccds_ori]] * 2,
|
| 212 |
+
dtype=np.int8)
|
| 213 |
+
chain_feature["deletion_matrix"] = np.zeros_like(chain_feature["msa"])
|
| 214 |
+
|
| 215 |
+
# Merge Key Res Feat & Augmentation
|
| 216 |
+
if "salt bridges" in chain_feature and use_key_res:
|
| 217 |
+
key_res_feat = np.stack([
|
| 218 |
+
chain_feature["salt bridges"],
|
| 219 |
+
chain_feature["pi-cation interactions"],
|
| 220 |
+
chain_feature["hydrophobic interactions"],
|
| 221 |
+
chain_feature["pi-stacking"],
|
| 222 |
+
chain_feature["hydrogen bonds"],
|
| 223 |
+
chain_feature["metal complexes"],
|
| 224 |
+
np.zeros_like(chain_feature["salt bridges"]),
|
| 225 |
+
], axis=-1).astype(np.float32)
|
| 226 |
+
else:
|
| 227 |
+
key_res_feat = np.zeros([len(ccds_ori), 7], dtype=np.float32)
|
| 228 |
+
is_key_res = np.any(key_res_feat.astype(np.bool_), axis=-1).astype(np.float32)
|
| 229 |
+
# Augmentation
|
| 230 |
+
# if not self.inference_mode:
|
| 231 |
+
key_res_feat = (key_res_feat *
|
| 232 |
+
(np.random.random([len(ccds_ori), 7]) > self.key_res_random_mask_ratio))
|
| 233 |
+
if "pocket_res_feat" in chain_feature and use_pocket:
|
| 234 |
+
pocket_res_feat = chain_feature["pocket_res_feat"]
|
| 235 |
+
else:
|
| 236 |
+
pocket_res_feat = np.zeros([len(ccds_ori)], dtype=np.float32)
|
| 237 |
+
x_gt = []
|
| 238 |
+
atom_id_to_conformer_atom_id = []
|
| 239 |
+
|
| 240 |
+
# Conformer
|
| 241 |
+
conformer_id_to_chunk_sizes = []
|
| 242 |
+
residue_index = []
|
| 243 |
+
restype = []
|
| 244 |
+
ccds = []
|
| 245 |
+
|
| 246 |
+
conformer_exists = []
|
| 247 |
+
|
| 248 |
+
for c_id, ccd in enumerate(chain_feature["ccds"]):
|
| 249 |
+
no_atom_this_conf = len(CONF_META_DATA[ccd]["ref_feat"])
|
| 250 |
+
conformer_atom_ids_this_conf = np.arange(no_atom_this_conf)
|
| 251 |
+
|
| 252 |
+
x_gt_this_conf = chain_feature["all_atom_positions"][c_id]
|
| 253 |
+
x_exists_this_conf = chain_feature["all_atom_mask"][c_id].astype(np.bool_)
|
| 254 |
+
|
| 255 |
+
# TODO DEBUG
|
| 256 |
+
# conformer_exist = np.sum(x_exists_this_conf).item() > len(x_exists_this_conf) - 2
|
| 257 |
+
conformer_exist = np.any(x_exists_this_conf).item()
|
| 258 |
+
|
| 259 |
+
if rc.is_standard(ccd):
|
| 260 |
+
conformer_exist = conformer_exist and x_exists_this_conf[1]
|
| 261 |
+
if ccd != "GLY":
|
| 262 |
+
conformer_exist = conformer_exist and x_exists_this_conf[4]
|
| 263 |
+
|
| 264 |
+
conformer_exists.append(conformer_exist)
|
| 265 |
+
if conformer_exist:
|
| 266 |
+
# Atomwise
|
| 267 |
+
x_gt.append(x_gt_this_conf[x_exists_this_conf])
|
| 268 |
+
atom_id_to_conformer_atom_id.append(conformer_atom_ids_this_conf[x_exists_this_conf])
|
| 269 |
+
# Tokenwise
|
| 270 |
+
residue_index.append(c_id)
|
| 271 |
+
conformer_id_to_chunk_sizes.append(np.sum(x_exists_this_conf).item())
|
| 272 |
+
restype.append(rc.standard_ccds.index(ccd) if ccd in rc.standard_ccds else 20)
|
| 273 |
+
ccds.append(ccd)
|
| 274 |
+
# print(ccds)
|
| 275 |
+
# print("x_gt", x_gt)
|
| 276 |
+
x_gt = np.concatenate(x_gt, axis=0)
|
| 277 |
+
atom_id_to_conformer_atom_id = np.concatenate(atom_id_to_conformer_atom_id, axis=0, dtype=np.int32)
|
| 278 |
+
residue_index = np.array(residue_index, dtype=np.int64)
|
| 279 |
+
conformer_id_to_chunk_sizes = np.array(conformer_id_to_chunk_sizes, dtype=np.int64)
|
| 280 |
+
restype = np.array(restype, dtype=np.int64)
|
| 281 |
+
|
| 282 |
+
conformer_exists = np.array(conformer_exists, dtype=np.bool_)
|
| 283 |
+
|
| 284 |
+
chain_feature_update = {
|
| 285 |
+
"x_gt": x_gt,
|
| 286 |
+
"atom_id_to_conformer_atom_id": atom_id_to_conformer_atom_id,
|
| 287 |
+
"residue_index": residue_index,
|
| 288 |
+
"conformer_id_to_chunk_sizes": conformer_id_to_chunk_sizes,
|
| 289 |
+
"restype": restype,
|
| 290 |
+
"ccds": ccds,
|
| 291 |
+
"msa": chain_feature["msa"][:, conformer_exists],
|
| 292 |
+
"deletion_matrix": chain_feature["deletion_matrix"][:, conformer_exists],
|
| 293 |
+
"chain_class": chain_class,
|
| 294 |
+
"key_res_feat": key_res_feat[conformer_exists],
|
| 295 |
+
"is_key_res": is_key_res[conformer_exists],
|
| 296 |
+
"pocket_res_feat": pocket_res_feat[conformer_exists],
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
chain_feature_update["is_protein"] = np.array([chain_class == "protein"] * len(ccds)).astype(np.float32)
|
| 300 |
+
chain_feature_update["is_ligand"] = np.array([chain_class != "protein"] * len(ccds)).astype(np.float32)
|
| 301 |
+
# Assert Short Poly Chain like peptide
|
| 302 |
+
chain_feature_update["is_short_poly"] = np.array(
|
| 303 |
+
[chain_class != "protein" and len(ccds) >= 2 and rc.is_standard(ccd) for ccd in ccds]
|
| 304 |
+
).astype(np.float32)
|
| 305 |
+
|
| 306 |
+
if "msa_all_seq" in chain_feature:
|
| 307 |
+
chain_feature_update["msa_all_seq"] = chain_feature["msa_all_seq"][:, conformer_exists]
|
| 308 |
+
chain_feature_update["deletion_matrix_all_seq"] = \
|
| 309 |
+
chain_feature["deletion_matrix_all_seq"][:, conformer_exists]
|
| 310 |
+
chain_feature_update["msa_species_identifiers_all_seq"] = chain_feature["msa_species_identifiers_all_seq"]
|
| 311 |
+
del chain_feature
|
| 312 |
+
return chain_feature_update
|
| 313 |
+
|
| 314 |
+
def _update_smi(self, smi, all_chain_labels, CONF_META_DATA):
|
| 315 |
+
ccd = "XXX"
|
| 316 |
+
chain_id = "99"
|
| 317 |
+
label_feature, conf_feature, ref_mol = get_features_from_smi(smi)
|
| 318 |
+
all_chain_labels[chain_id] = {
|
| 319 |
+
"all_atom_positions": label_feature["x_gt"][None],
|
| 320 |
+
"all_atom_mask": label_feature["x_exists"][None],
|
| 321 |
+
"ccds": [ccd]
|
| 322 |
+
}
|
| 323 |
+
ref_atom_name_chars = []
|
| 324 |
+
for id, ele in enumerate(conf_feature["ref_element"]):
|
| 325 |
+
atom_name = f"{PeriodicTable[ele] + str(id):<4}"
|
| 326 |
+
ref_atom_name_chars.append(atom_name)
|
| 327 |
+
CONF_META_DATA[ccd] = {
|
| 328 |
+
"ref_feat": np.concatenate([
|
| 329 |
+
conf_feature["ref_pos"],
|
| 330 |
+
conf_feature["ref_charge"][..., None],
|
| 331 |
+
rc.eye_128[conf_feature["ref_element"]].astype(np.float32),
|
| 332 |
+
conf_feature["ref_is_aromatic"].astype(np.float32)[..., None],
|
| 333 |
+
rc.eye_9[conf_feature["ref_degree"]].astype(np.float32),
|
| 334 |
+
rc.eye_7[conf_feature["ref_hybridization"]].astype(np.float32),
|
| 335 |
+
rc.eye_9[conf_feature["ref_implicit_valence"]].astype(np.float32),
|
| 336 |
+
rc.eye_3[conf_feature["ref_chirality"]].astype(np.float32),
|
| 337 |
+
conf_feature["ref_in_ring_of_3"].astype(np.float32)[..., None],
|
| 338 |
+
conf_feature["ref_in_ring_of_4"].astype(np.float32)[..., None],
|
| 339 |
+
conf_feature["ref_in_ring_of_5"].astype(np.float32)[..., None],
|
| 340 |
+
conf_feature["ref_in_ring_of_6"].astype(np.float32)[..., None],
|
| 341 |
+
conf_feature["ref_in_ring_of_7"].astype(np.float32)[..., None],
|
| 342 |
+
conf_feature["ref_in_ring_of_8"].astype(np.float32)[..., None],
|
| 343 |
+
], axis=-1),
|
| 344 |
+
"rel_tok_feat": np.concatenate([
|
| 345 |
+
rc.eye_32[conf_feature["d_token"]].astype(np.float32),
|
| 346 |
+
rc.eye_5[conf_feature["bond_type"]].astype(np.float32),
|
| 347 |
+
conf_feature["token_bonds"].astype(np.float32)[..., None],
|
| 348 |
+
conf_feature["bond_as_double"].astype(np.float32)[..., None],
|
| 349 |
+
conf_feature["bond_in_ring"].astype(np.float32)[..., None],
|
| 350 |
+
conf_feature["bond_is_conjugated"].astype(np.float32)[..., None],
|
| 351 |
+
conf_feature["bond_is_aromatic"].astype(np.float32)[..., None],
|
| 352 |
+
], axis=-1),
|
| 353 |
+
"ref_atom_name_chars": ref_atom_name_chars,
|
| 354 |
+
"ref_element": conf_feature["ref_element"],
|
| 355 |
+
"token_bonds": conf_feature["token_bonds"],
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
return all_chain_labels, CONF_META_DATA, ref_mol
|
| 359 |
+
|
| 360 |
+
def _add_assembly_feature(self, all_chain_features, SEQ3, ASYM_ID):
|
| 361 |
+
entities = {}
|
| 362 |
+
for chain_id, seq3 in SEQ3.items():
|
| 363 |
+
if seq3 not in entities:
|
| 364 |
+
entities[seq3] = [chain_id]
|
| 365 |
+
else:
|
| 366 |
+
entities[seq3].append(chain_id)
|
| 367 |
+
|
| 368 |
+
asym_id = 0
|
| 369 |
+
for entity_id, seq3 in enumerate(list(entities.keys())):
|
| 370 |
+
chain_ids = copy.deepcopy(entities[seq3])
|
| 371 |
+
if not self.inference_mode and self.train_shuffle_sym_id:
|
| 372 |
+
# sym_id augmentation
|
| 373 |
+
random.shuffle(chain_ids)
|
| 374 |
+
for sym_id, chain_id in enumerate(chain_ids):
|
| 375 |
+
num_conformers = len(all_chain_features[chain_id]["ccds"])
|
| 376 |
+
all_chain_features[chain_id]["asym_id"] = \
|
| 377 |
+
np.full([num_conformers], fill_value=asym_id, dtype=np.int32)
|
| 378 |
+
all_chain_features[chain_id]["sym_id"] = \
|
| 379 |
+
np.full([num_conformers], fill_value=sym_id, dtype=np.int32)
|
| 380 |
+
all_chain_features[chain_id]["entity_id"] = \
|
| 381 |
+
np.full([num_conformers], fill_value=entity_id, dtype=np.int32)
|
| 382 |
+
|
| 383 |
+
all_chain_features[chain_id]["sequence_3"] = seq3
|
| 384 |
+
ASYM_ID[asym_id] = chain_id
|
| 385 |
+
|
| 386 |
+
asym_id += 1
|
| 387 |
+
return all_chain_features, ASYM_ID
|
| 388 |
+
|
| 389 |
+
def _crop_all_chain_features(self, all_chain_features, infer_meta_data, crop_centre=None):
    """Spatially crop the system to at most ``self.crop_size`` tokens and
    ``self.atom_crop_size`` atoms.

    A crop centre is chosen (single-ligand centroid at inference; during
    training a random ligand atom, an interface token, or a random token,
    weighted by the ``train_spatial_crop_*`` ratios).  Tokens are ranked by
    distance to the centre and whole conformers are greedily selected until
    either budget would overflow.  Chains that lose every conformer are
    removed.  Mutates and returns ``all_chain_features`` / ``infer_meta_data``.

    Note: ``crop_centre`` is currently unused (kept for interface
    compatibility).
    """
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ordered_chain_ids = list(all_chain_features.keys())

    x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)

    # Per-token bookkeeping: centre atom, owning conformer, conformer atom
    # count, CCD code and asym id (one entry per token).
    token_id_to_centre_atom_id = []
    token_id_to_conformer_id = []
    token_id_to_ccd_chunk_sizes = []
    token_id_to_ccd = []
    asym_id_ca = []
    atom_id = 0
    conf_id = 0
    x_gt_ligand = []
    for chain_id in ordered_chain_ids:
        # Numeric chain ids with a single conformer are ligand chains.
        if chain_id.isdigit() and len(all_chain_features[chain_id]["ccds"]) == 1:
            x_gt_ligand.append(all_chain_features[chain_id]["x_gt"])
        atom_offset = 0
        for ccd, chunk_size_this_ccd, asym_id in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
                all_chain_features[chain_id]["asym_id"],
        ):
            inner_atom_idx = all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][
                atom_offset:atom_offset + chunk_size_this_ccd]
            atom_names = [CONF_META_DATA[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
            if rc.is_standard(ccd):
                # Standard residue: one token, centred on its table-defined
                # centre atom.
                for atom_name in atom_names:
                    if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                        token_id_to_centre_atom_id.append(atom_id)
                        token_id_to_conformer_id.append(conf_id)
                        token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                        token_id_to_ccd.append(ccd)
                        asym_id_ca.append(asym_id)
                    atom_id += 1
            else:
                # Non-standard residue / ligand: every atom is its own token.
                for _ in atom_names:
                    token_id_to_centre_atom_id.append(atom_id)
                    token_id_to_conformer_id.append(conf_id)
                    token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                    token_id_to_ccd.append(ccd)
                    asym_id_ca.append(asym_id)
                    atom_id += 1
            atom_offset += chunk_size_this_ccd
            conf_id += 1

    x_gt_ca = x_gt[token_id_to_centre_atom_id]
    asym_id_ca = np.array(asym_id_ca)

    crop_scheme_seed = random.random()

    # Crop Ligand Centre
    if self.inference_mode and len(x_gt_ligand) == 1:
        x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
        x_gt_sel = np.mean(x_gt_ligand, axis=0)[None]

    # Spatial Crop Ligand
    elif crop_scheme_seed < (self.train_spatial_crop_ligand_ratio if not self.inference_mode else 1.0) and len(
            x_gt_ligand) > 0:
        x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
        x_gt_sel = random.choice(x_gt_ligand)[None]
    # Spatial Crop Interface
    elif crop_scheme_seed < self.train_spatial_crop_ligand_ratio + self.train_spatial_crop_interface_ratio and len(
            set(asym_id_ca.tolist())) > 1:
        # BUG FIX: the original computed chain_same as the *product* of raw
        # asym ids, which is neither 1 for same-chain pairs (asym 0) nor 0 for
        # cross-chain pairs; use an equality mask so only same-chain pairs are
        # pushed out of interface range.
        chain_same = (asym_id_ca[None] == asym_id_ca[:, None]).astype(np.float32)
        dist = np.linalg.norm(x_gt_ca[:, None] - x_gt_ca[None], axis=-1)

        dist = dist + chain_same * 100
        # interface_threshold: tokens with any cross-chain neighbour closer
        # than this count as interface.
        mask = np.any(dist < self.train_spatial_crop_interface_threshold, axis=-1)
        if sum(mask) > 0:
            x_gt_sel = random.choice(x_gt_ca[mask])[None]
        else:
            x_gt_sel = random.choice(x_gt_ca)[None]
    # Spatial Crop
    else:
        x_gt_sel = random.choice(x_gt_ca)[None]
    dist = np.linalg.norm(x_gt_ca - x_gt_sel, axis=-1)
    token_idxs = np.argsort(dist)

    # Greedily take whole conformers, nearest tokens first, until either the
    # atom or the token budget would overflow.  (Set membership instead of the
    # original list: O(1) per lookup, order is never used.)
    selected_conf_ids = set()
    to_sum_atom = 0
    to_sum_token = 0
    for token_idx in token_idxs:
        ccd_idx = token_id_to_conformer_id[token_idx]
        ccd_chunk_size = token_id_to_ccd_chunk_sizes[token_idx]
        ccd_this_token = token_id_to_ccd[token_idx]
        if ccd_idx in selected_conf_ids:
            continue
        if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
            break
        # Standard residues cost one token; ligands cost one token per atom.
        to_add_token = 1 if rc.is_standard(ccd_this_token) else ccd_chunk_size
        if to_sum_token + to_add_token > self.crop_size:
            break
        selected_conf_ids.add(ccd_idx)
        to_sum_atom += ccd_chunk_size
        to_sum_token += to_add_token

    # Apply the conformer selection chain by chain.
    ccd_all_id = 0
    crop_chains = []
    for chain_id in ordered_chain_ids:
        conformer_used_mask = []
        atom_used_mask = []
        ccds = []
        for ccd, chunk_size_this_ccd in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
        ):
            used = ccd_all_id in selected_conf_ids
            if used:
                ccds.append(ccd)
                if chain_id not in crop_chains:
                    crop_chains.append(chain_id)
            conformer_used_mask.append(used)
            atom_used_mask += [used] * chunk_size_this_ccd
            ccd_all_id += 1
        conf_mask = np.array(conformer_used_mask).astype(np.bool_)
        atom_mask = np.array(atom_used_mask).astype(np.bool_)
        # Update All Chain Features: atomwise arrays follow atom_mask.
        all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
        all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
            all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
        # Conformerwise arrays follow conf_mask.
        for key in ["restype", "residue_index", "conformer_id_to_chunk_sizes",
                    "key_res_feat", "pocket_res_feat", "is_key_res",
                    "is_protein", "is_short_poly", "is_ligand",
                    "asym_id", "sym_id", "entity_id"]:
            all_chain_features[chain_id][key] = all_chain_features[chain_id][key][conf_mask]

        all_chain_features[chain_id]["ccds"] = ccds
        # MSA columns are per-conformer as well.
        if "msa" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix"] = \
                all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
        if "msa_all_seq" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
                all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
    # Remove chains that lost every conformer.
    for chain_id in list(all_chain_features.keys()):
        if chain_id not in crop_chains:
            all_chain_features.pop(chain_id, None)
    return all_chain_features, infer_meta_data
|
| 544 |
+
|
| 545 |
+
def _make_ccd_features(self, raw_feats, infer_meta_data):
    """Expand per-CCD conformer metadata into atomwise and tokenwise features.

    Tokenization scheme: a standard residue contributes one token spanning
    all of its atoms; a non-standard residue or ligand contributes one token
    per atom; an UNK conformer contributes a single masked token with no
    atoms.  Returns a dict of index-mapping arrays plus the per-atom
    reference feature matrix ``ref_feat`` (its first three channels are
    exposed again as ``ref_pos``).
    """
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ccds = raw_feats["ccds"]
    atom_id_to_conformer_atom_id = raw_feats["atom_id_to_conformer_atom_id"]
    conformer_id_to_chunk_sizes = raw_feats["conformer_id_to_chunk_sizes"]

    # Atomwise accumulators.
    atom_id_to_conformer_id = []
    atom_id_to_token_id = []
    ref_feat = []

    # Tokenwise accumulators.
    s_mask = []
    token_id_to_conformer_id = []
    token_id_to_chunk_sizes = []
    token_id_to_centre_atom_id = []
    token_id_to_pseudo_beta_atom_id = []

    token_id = 0
    atom_id = 0
    for conf_id, (ccd, ccd_atoms) in enumerate(zip(ccds, conformer_id_to_chunk_sizes)):
        conf_meta_data = CONF_META_DATA[ccd]
        # UNK Conformer: one masked token, no atoms; -1 marks "no such atom".
        if rc.is_unk(ccd):
            s_mask.append(0)
            token_id_to_chunk_sizes.append(0)
            token_id_to_conformer_id.append(conf_id)
            token_id_to_centre_atom_id.append(-1)
            token_id_to_pseudo_beta_atom_id.append(-1)
            token_id += 1
        # Standard Residue: one token covering all ccd_atoms atoms.
        elif rc.is_standard(ccd):
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
            atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
            ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
            token_id_to_conformer_id.append(conf_id)
            token_id_to_chunk_sizes.append(ccd_atoms.item())
            s_mask.append(1)
            for atom_id_this_ccd, atom_name in enumerate(atom_names):
                # Update Atomwise Features
                atom_id_to_conformer_id.append(conf_id)
                atom_id_to_token_id.append(token_id)
                # Record the representative atoms looked up from the per-CCD
                # name tables (centre and pseudo-beta).
                if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                    token_id_to_centre_atom_id.append(atom_id)
                if atom_name == rc.standard_ccd_to_token_pseudo_beta_atom_name[ccd]:
                    token_id_to_pseudo_beta_atom_id.append(atom_id)
                atom_id += 1
            token_id += 1
        # Nonestandard Residue & Ligand: one token per atom; each atom serves
        # as its own centre and pseudo-beta atom.
        else:
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
            atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
            ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
            # ref_pos_new.append(conf_meta_data["ref_pos_new"][:, inner_atom_idx])
            for atom_id_this_ccd, atom_name in enumerate(atom_names):
                # Update Atomwise Features
                atom_id_to_conformer_id.append(conf_id)
                atom_id_to_token_id.append(token_id)
                # Update Tokenwise Features
                token_id_to_chunk_sizes.append(1)
                token_id_to_conformer_id.append(conf_id)
                s_mask.append(1)
                token_id_to_centre_atom_id.append(atom_id)
                token_id_to_pseudo_beta_atom_id.append(atom_id)
                atom_id += 1
                token_id += 1

    # NOTE(review): assumes at least one non-UNK conformer; an all-UNK input
    # leaves ref_feat empty and the else-branch would raise IndexError —
    # confirm callers guarantee this.
    if len(ref_feat) > 1:
        ref_feat = np.concatenate(ref_feat, axis=0).astype(np.float32)
    else:
        ref_feat = ref_feat[0].astype(np.float32)

    features = {
        # Atomwise
        "atom_id_to_conformer_id": np.array(atom_id_to_conformer_id, dtype=np.int64),
        "atom_id_to_token_id": np.array(atom_id_to_token_id, dtype=np.int64),
        "ref_feat": ref_feat,
        # Tokenwise
        "token_id_to_conformer_id": np.array(token_id_to_conformer_id, dtype=np.int64),
        "s_mask": np.array(s_mask, dtype=np.int64),
        "token_id_to_centre_atom_id": np.array(token_id_to_centre_atom_id, dtype=np.int64),
        "token_id_to_pseudo_beta_atom_id": np.array(token_id_to_pseudo_beta_atom_id, dtype=np.int64),
        "token_id_to_chunk_sizes": np.array(token_id_to_chunk_sizes, dtype=np.int64),
    }
    # The first three channels of ref_feat are the reference coordinates.
    features["ref_pos"] = features["ref_feat"][..., :3]
    return features
|
| 632 |
+
|
| 633 |
+
def pair_and_merge(self, all_chain_features, infer_meta_data):
    """Pair per-chain MSAs, merge all chains into one feature dict, and build
    atomwise / tokenwise and token-pair features.

    Two paths:
      * single ligand-only system: the chain's features are used directly and
        a dummy 256-row MSA is fabricated by repeating the query row;
      * everything else: per-chain MSAs are cross-chain paired and merged via
        the module-level ``pair_and_merge`` helper, then per-chain arrays are
        concatenated in asym_id order.

    Mutates ``all_chain_features`` (features are popped out) and
    ``infer_meta_data`` (meta arrays are written back).
    Returns ``(raw_feats, infer_meta_data)``.
    """
    CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]  # Dict: chain_id -> class
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ASYM_ID = infer_meta_data["ASYM_ID"]
    homo_feats = {}

    all_chain_ids = list(all_chain_features.keys())
    if len(all_chain_ids) == 1 and CHAIN_CLASS[all_chain_ids[0]] == "ligand":
        # Ligand-only system: no pairing needed.  NOTE: raw_feats aliases the
        # chain's own dict, so the pops below mutate the same object.
        ordered_chain_ids = all_chain_ids
        raw_feats = all_chain_features[all_chain_ids[0]]
        raw_feats["msa"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
        # BUG FIX: the deletion matrix must be repeated from its own first
        # row; the original repeated the (already expanded) MSA row instead.
        raw_feats["deletion_matrix"] = np.repeat(raw_feats["deletion_matrix"][:1], 256, axis=0)
        keys = list(raw_feats.keys())

        # Drop everything not consumed downstream.
        for feature_name in keys:
            if feature_name not in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index",
                                    "conformer_id_to_chunk_sizes", "restype", "is_protein", "is_short_poly",
                                    "is_ligand",
                                    "asym_id", "sym_id", "entity_id", "msa", "deletion_matrix", "ccds",
                                    "pocket_res_feat", "key_res_feat", "is_key_res"]:
                raw_feats.pop(feature_name)

        # Update Profile and Deletion Mean
        msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
        raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
        del msa_one_hot
        raw_feats["deletion_mean"] = (torch.atan(
            torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
        ) * (2.0 / torch.pi)).numpy()
    else:
        # Collect per-chain assembly ids plus the features needed for
        # cross-chain MSA pairing.
        for chain_id in list(all_chain_features.keys()):
            homo_feats[chain_id] = {
                "asym_id": copy.deepcopy(all_chain_features[chain_id]["asym_id"]),
                "sym_id": copy.deepcopy(all_chain_features[chain_id]["sym_id"]),
                "entity_id": copy.deepcopy(all_chain_features[chain_id]["entity_id"]),
            }
        for chain_id in list(all_chain_features.keys()):
            homo_feats[chain_id]["chain_class"] = all_chain_features[chain_id].pop("chain_class")
            homo_feats[chain_id]["sequence_3"] = all_chain_features[chain_id].pop("sequence_3")
            homo_feats[chain_id]["msa"] = all_chain_features[chain_id].pop("msa")
            homo_feats[chain_id]["deletion_matrix"] = all_chain_features[chain_id].pop("deletion_matrix")
            if "msa_all_seq" in all_chain_features[chain_id]:
                homo_feats[chain_id]["msa_all_seq"] = all_chain_features[chain_id].pop("msa_all_seq")
                homo_feats[chain_id]["deletion_matrix_all_seq"] = all_chain_features[chain_id].pop(
                    "deletion_matrix_all_seq")
                homo_feats[chain_id]["msa_species_identifiers_all_seq"] = all_chain_features[chain_id].pop(
                    "msa_species_identifiers_all_seq")

        # Initial raw feats with merged homo feats (module-level helper, not
        # this method).
        raw_feats = pair_and_merge(homo_feats, is_homomer_or_monomer=False)

        # Update Profile and Deletion Mean
        msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
        raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
        del msa_one_hot
        raw_feats["deletion_mean"] = (torch.atan(
            torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
        ) * (2.0 / torch.pi)).numpy()

        # Merge non-homo feats according to asym_id (first-occurrence order).
        ordered_asym_ids = []
        for i in raw_feats["asym_id"]:
            if i not in ordered_asym_ids:
                ordered_asym_ids.append(i)
        ordered_chain_ids = [ASYM_ID[i] for i in ordered_asym_ids]
        for feature_name in ["chain_class", "sequence_3", "assembly_num_chains", "entity_mask", "seq_length",
                             "num_alignments"]:
            raw_feats.pop(feature_name, None)

    for feature_name in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index", "conformer_id_to_chunk_sizes",
                         "restype", "is_protein", "is_short_poly", "is_ligand", "pocket_res_feat",
                         "key_res_feat", "is_key_res"]:
        raw_feats[feature_name] = np.concatenate([
            all_chain_features[chain_id].pop(feature_name) for chain_id in ordered_chain_ids
        ], axis=0)

    # Conformerwise Chain Class: expand to one entry per conformer.
    CHAIN_CLASS_NEW = []
    for chain_id in ordered_chain_ids:
        CHAIN_CLASS_NEW += [CHAIN_CLASS[chain_id]] * len(all_chain_features[chain_id]["ccds"])
    infer_meta_data["CHAIN_CLASS"] = CHAIN_CLASS_NEW

    raw_feats["ccds"] = reduce(add, [all_chain_features[chain_id].pop("ccds") for chain_id in ordered_chain_ids])

    # Create Atomwise and Tokenwise Features
    raw_feats.update(self._make_ccd_features(raw_feats, infer_meta_data))
    if self.use_x_gt_ligand_as_ref_pos:
        # Replace ligand reference positions with the centred ground-truth
        # ligand coordinates.
        is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_conformer_id"]].astype(np.bool_)
        raw_feats["ref_pos"][is_ligand_atom] = raw_feats["x_gt"][is_ligand_atom] - np.mean(
            raw_feats["x_gt"][is_ligand_atom], axis=0, keepdims=True)

    # Keep conformerwise copies before the tokenwise gather below.
    asym_id_conformerwise = copy.deepcopy(raw_feats["asym_id"])
    residue_index_conformerwise = copy.deepcopy(raw_feats["residue_index"])

    # Conformerwise to Tokenwise: gather the per-conformer row for each token.
    token_id_to_conformer_id = raw_feats["token_id_to_conformer_id"]
    for key in ["is_protein", "is_short_poly", "is_ligand", "residue_index", "restype", "asym_id", "entity_id",
                "sym_id", "deletion_mean", "profile", "pocket_res_feat", "key_res_feat", "is_key_res"]:
        raw_feats[key] = raw_feats[key][token_id_to_conformer_id]
    for key in ["msa", "deletion_matrix"]:
        if key in raw_feats:
            raw_feats[key] = raw_feats[key][:, token_id_to_conformer_id]

    ###################################################
    #     Centre Random Augmentation of ref pos       #
    ###################################################
    raw_feats["ref_pos"] = centre_random_augmentation_np_apply(
        raw_feats["ref_pos"], raw_feats["atom_id_to_conformer_id"]).astype(np.float32)
    raw_feats["ref_feat"][:, :3] = raw_feats["ref_pos"]

    ###################################################
    #          Create token pair features             #
    ###################################################
    no_token = len(raw_feats["token_id_to_conformer_id"])
    token_bonds = np.zeros([no_token, no_token], dtype=np.float32)
    rel_tok_feat = np.zeros([no_token, no_token, 42], dtype=np.float32)
    offset = 0
    atom_offset = 0
    for ccd, len_atoms in zip(
            raw_feats["ccds"],
            raw_feats["conformer_id_to_chunk_sizes"]
    ):
        if rc.is_standard(ccd) or rc.is_unk(ccd):
            # NOTE(review): atom_offset is not advanced here, so when standard
            # residues precede a ligand, the inner_atom_idx slice below may
            # start at the wrong atom — confirm intended indexing.
            offset += 1
        else:
            len_atoms = len_atoms.item()
            inner_atom_idx = raw_feats["atom_id_to_conformer_atom_id"][atom_offset:atom_offset + len_atoms]
            token_bonds[offset:offset + len_atoms, offset:offset + len_atoms] = \
                CONF_META_DATA[ccd]["token_bonds"][inner_atom_idx][:, inner_atom_idx]
            rel_tok_feat[offset:offset + len_atoms, offset:offset + len_atoms] = \
                CONF_META_DATA[ccd]["rel_tok_feat"][inner_atom_idx][:, inner_atom_idx]
            offset += len_atoms
            atom_offset += len_atoms
    raw_feats["token_bonds"] = token_bonds.astype(np.float32)
    raw_feats["token_bonds_feature"] = token_bonds.astype(np.float32)
    raw_feats["rel_tok_feat"] = rel_tok_feat.astype(np.float32)
    ###################################################
    #             Chirality Augmentation              #
    ###################################################
    if not self.inference_mode:
        # With probability train_charility_augmentation_ratio, randomly erase
        # chirality channels (158:161) on ligand atoms, replacing them with
        # the "no chirality" one-hot at channel 160.
        charility_seed = random.random()
        if charility_seed < self.train_charility_augmentation_ratio:
            ref_chirality = raw_feats["ref_feat"][:, 158:161]
            ref_chirality_replace = np.zeros_like(ref_chirality)
            ref_chirality_replace[:, 2] = 1

            is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_token_id"]]
            remove_charility = (np.random.randint(0, 2, [len(is_ligand_atom)]) * is_ligand_atom).astype(
                np.bool_)
            ref_chirality = np.where(remove_charility[:, None], ref_chirality_replace, ref_chirality)
            raw_feats["ref_feat"][:, 158:161] = ref_chirality

    # MASKS: everything visible.  Note this overwrites the s_mask produced by
    # _make_ccd_features (which zeroed UNK tokens).
    raw_feats["x_exists"] = np.ones_like(raw_feats["x_gt"][..., 0]).astype(np.float32)
    raw_feats["a_mask"] = raw_feats["x_exists"]
    raw_feats["s_mask"] = np.ones_like(raw_feats["asym_id"]).astype(np.float32)
    raw_feats["ref_space_uid"] = raw_feats["atom_id_to_conformer_id"]

    # Write Infer Meta Data
    infer_meta_data["ccds"] = raw_feats.pop("ccds")
    infer_meta_data["atom_id_to_conformer_atom_id"] = raw_feats.pop("atom_id_to_conformer_atom_id")
    infer_meta_data["residue_index"] = residue_index_conformerwise
    infer_meta_data["asym_id"] = asym_id_conformerwise
    infer_meta_data["conformer_id_to_chunk_sizes"] = raw_feats.pop("conformer_id_to_chunk_sizes")

    return raw_feats, infer_meta_data
|
| 802 |
+
|
| 803 |
+
def make_feats(self, tensors):
    """Build the model input features ``target_feat`` and ``msa_feat``.

    target_feat = one-hot restype (32) + MSA profile (32) + deletion mean (1).
    msa_feat is built from a random subsample of MSA rows (the query row 0 is
    always kept).  When ``self.num_recycles`` is set, one sample is drawn per
    recycle and stacked into ``batch_msa_feat``; ``msa_feat`` is then the
    first sample.  Consumed raw keys (msa/profile/deletion_*) are popped.

    Refactor: the previously duplicated sampling code in both branches is
    extracted into a single local helper; RNG call order is unchanged.
    """
    # Target Feat
    tensors["target_feat"] = torch.cat([
        F.one_hot(tensors["restype"].long(), 32).float(),
        tensors["profile"].float(),
        tensors["deletion_mean"][..., None].float()
    ], dim=-1)

    def _sample_msa_feat():
        # Keep row 0 (the query) plus a random subset of rows.  NOTE: this
        # mutates tensors["msa"]/["deletion_matrix"] in place, so repeated
        # calls resample from the already-subsampled MSA (original behavior).
        inds = [0] + torch.randperm(len(tensors["msa"]))[:self.max_msa_clusters - 1].tolist()

        tensors["msa"] = tensors["msa"][inds]
        tensors["deletion_matrix"] = tensors["deletion_matrix"][inds]

        has_deletion = torch.clamp(tensors["deletion_matrix"].float(), min=0., max=1.)
        pi = torch.acos(torch.zeros(1, device=tensors["deletion_matrix"].device)) * 2
        deletion_value = (torch.atan(tensors["deletion_matrix"] / 3.) * (2. / pi))
        return torch.cat([
            F.one_hot(tensors["msa"].long(), 32).float(),
            has_deletion[..., None].float(),
            deletion_value[..., None].float(),
        ], dim=-1)

    if self.num_recycles is None:
        # Single MSA sample.
        tensors["msa_feat"] = _sample_msa_feat()
    else:
        # One independent sample per recycle iteration.
        batch_msa_feat = [_sample_msa_feat() for _ in range(self.num_recycles)]
        tensors["msa_feat"] = batch_msa_feat[0]
        tensors["batch_msa_feat"] = torch.stack(batch_msa_feat, dim=0)

    # Raw MSA inputs are fully consumed above.
    for key in ("msa", "deletion_mean", "profile", "deletion_matrix"):
        tensors.pop(key, None)

    return tensors
|
| 852 |
+
|
| 853 |
+
def _make_token_bonds(self, tensors):
    """Add inter-chain "docking" bonds between nearest polymer-ligand and
    ligand-ligand atom pairs to the token-bond matrix.

    For every pair of chains where at least one side is a ligand, the single
    closest (masked) atom pair is found; if it is within
    ``self.token_bond_threshold`` the corresponding token pair is marked as
    bonded (symmetrically).  Mutates and returns ``tensors``.
    """
    # Get Polymer-Ligand & Ligand-Ligand Within Conformer Token Bond
    # Broadcast tokenwise asym_id / is_ligand to atoms.
    asym_id = tensors["asym_id"][tensors["atom_id_to_token_id"]]
    is_ligand = tensors["is_ligand"][tensors["atom_id_to_token_id"]]

    x_gt = tensors["x_gt"]
    a_mask = tensors["a_mask"]

    # Atom -> token index map, used to translate atom hits back to tokens.
    atom_id_to_token_id = tensors["atom_id_to_token_id"]

    num_token = len(tensors["asym_id"])
    between_conformer_token_bonds = torch.zeros([num_token, num_token])

    # Per-chain bookkeeping: first asym id, its first atom offset, and its
    # ligand flag.  NOTE(review): this assumes atoms of one asym_id are
    # contiguous — confirm upstream ordering guarantees this.
    asym_id_chain = []
    asym_id_atom_offset = []
    asym_id_is_ligand = []
    for atom_offset, (a_id, i_id) in enumerate(zip(asym_id.tolist(), is_ligand.tolist())):
        if len(asym_id_chain) == 0 or asym_id_chain[-1] != a_id:
            asym_id_chain.append(a_id)
            asym_id_atom_offset.append(atom_offset)
            asym_id_is_ligand.append(i_id)

    len_asym_id_chain = len(asym_id_chain)
    if len_asym_id_chain >= 2:
        # All unordered chain pairs (i < j).
        for i in range(0, len_asym_id_chain - 1):
            asym_id_i = asym_id_chain[i]
            mask_i = asym_id == asym_id_i
            x_gt_i = x_gt[mask_i]
            a_mask_i = a_mask[mask_i]
            for j in range(i + 1, len_asym_id_chain):
                # Skip polymer-polymer pairs: only pairs involving a ligand
                # get a docking bond.
                if not bool(asym_id_is_ligand[i]) and not bool(asym_id_is_ligand[j]):
                    continue
                asym_id_j = asym_id_chain[j]
                mask_j = asym_id == asym_id_j
                x_gt_j = x_gt[mask_j]
                a_mask_j = a_mask[mask_j]
                # Pairwise atom distances; masked pairs pushed out of range.
                dis_ij = torch.norm(x_gt_i[:, None, :] - x_gt_j[None, :, :], dim=-1)
                dis_ij = dis_ij + (1 - a_mask_i[:, None] * a_mask_j[None]) * 1000
                if torch.min(dis_ij) < self.token_bond_threshold:
                    # Flat argmin -> (row, col) of the closest atom pair.
                    ij = torch.argmin(dis_ij).item()
                    l_j = len(x_gt_j)
                    atom_i = int(ij // l_j)  # row (atom within chain i)
                    atom_j = int(ij % l_j)  # col (atom within chain j)
                    global_atom_i = atom_i + asym_id_atom_offset[i]
                    global_atom_j = atom_j + asym_id_atom_offset[j]
                    token_i = atom_id_to_token_id[global_atom_i]
                    token_j = atom_id_to_token_id[global_atom_j]

                    between_conformer_token_bonds[token_i, token_j] = 1
                    between_conformer_token_bonds[token_j, token_i] = 1
    # NOTE(review): token_bond_seed is currently unused (the block below is
    # commented out) but the call still advances the RNG stream; removing it
    # would shift subsequent random draws.
    token_bond_seed = random.random()
    tensors["token_bonds"] = tensors["token_bonds"] + between_conformer_token_bonds
    # Docking Indicate Token Bond
    # if token_bond_seed >= 1:
    #     tensors["token_bonds_feature"] = tensors["token_bonds"]
    return tensors
|
| 912 |
+
|
| 913 |
+
def _pad_to_size(self, tensors):
|
| 914 |
+
|
| 915 |
+
to_pad_atom = self.atom_crop_size - len(tensors["x_gt"])
|
| 916 |
+
to_pad_token = self.crop_size - len(tensors["residue_index"])
|
| 917 |
+
if to_pad_token > 0:
|
| 918 |
+
for k in ["restype", "residue_index", "is_protein", "is_short_poly", "is_ligand", "is_key_res",
|
| 919 |
+
"asym_id", "entity_id", "sym_id", "token_id_to_conformer_id", "s_mask",
|
| 920 |
+
"token_id_to_centre_atom_id", "token_id_to_pseudo_beta_atom_id", "token_id_to_chunk_sizes",
|
| 921 |
+
"pocket_res_feat"]:
|
| 922 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token])
|
| 923 |
+
for k in ["target_feat", "msa_feat", "key_res_feat", "batch_msa_feat"]:
|
| 924 |
+
if k in tensors:
|
| 925 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token])
|
| 926 |
+
for k in ["token_bonds", "token_bonds_feature"]:
|
| 927 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token, 0, to_pad_token])
|
| 928 |
+
for k in ["rel_tok_feat"]:
|
| 929 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token, 0, to_pad_token])
|
| 930 |
+
if to_pad_atom > 0:
|
| 931 |
+
for k in ["a_mask", "x_exists", "atom_id_to_conformer_id", "atom_id_to_token_id", "ref_space_uid"]:
|
| 932 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom])
|
| 933 |
+
for k in ["x_gt", "ref_feat", "ref_pos"]: # , "ref_pos_new"
|
| 934 |
+
tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_atom])
|
| 935 |
+
# for k in ["z_mask"]: # , "ref_pos_new"
|
| 936 |
+
# tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
|
| 937 |
+
# for k in ["conformer_mask_atom"]:
|
| 938 |
+
# tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
|
| 939 |
+
# for k in ["rel_token_feat_atom"]:
|
| 940 |
+
# tensors[k] = torch.nn.functional.pad(tensors[k], [0,0,0, to_pad_atom, 0, to_pad_atom])
|
| 941 |
+
# rel_token_feat_atom
|
| 942 |
+
return tensors
|
| 943 |
+
|
| 944 |
+
def get_template_feat(self, tensors):
    """Build the template feature: a masked distogram over pseudo-beta atoms.

    The distogram is restricted to protein-protein token pairs within the
    valid z_mask.  During training, with probability tied to
    ``train_use_template_ratio`` a random bert-style token dropout is applied
    to the template mask; otherwise / at inference the full mask is used.
    Mutates and returns ``tensors`` (adds "t_mask" and "templ_feat").
    """
    # Pseudo-beta coordinates per token.
    x_gt = tensors["x_gt"][tensors["token_id_to_pseudo_beta_atom_id"]]
    z_mask = tensors["z_mask"]
    asym_id = tensors["asym_id"]
    is_protein = tensors["is_protein"]
    # NOTE(review): chain_same is computed but never used in this function.
    chain_same = (asym_id[None] == asym_id[:, None]).float()
    protein2d = is_protein[None] * is_protein[:, None]
    # 39-bin distogram of pairwise pseudo-beta distances.
    dgram = dgram_from_positions(x_gt, no_bins=39)
    dgram = dgram * protein2d[..., None] * z_mask[..., None]

    if not self.inference_mode:
        # NOTE(review): t_mask=1 when random() *exceeds* the use-template
        # ratio — confirm whether the ratio is "probability of using" (then
        # the comparison looks inverted) or "probability of dropping".
        if random.random() > self.train_use_template_ratio:
            tensors["t_mask"] = torch.tensor(1, dtype=torch.float32)
            # Bert-style dropout: keep a random subset of tokens; the kept
            # fraction is itself randomized up to train_template_mask_max_ratio.
            bert_mask = torch.rand([len(x_gt)]) > random.random() * (1 - self.train_template_mask_max_ratio)
            template_pseudo_beta_mask = (bert_mask[None] * bert_mask[:, None]) * z_mask * protein2d
        else:
            tensors["t_mask"] = torch.tensor(0, dtype=torch.float32)
            template_pseudo_beta_mask = z_mask * protein2d
    else:
        # Inference: always provide the full template mask.
        tensors["t_mask"] = torch.tensor(1, dtype=torch.float32)
        template_pseudo_beta_mask = z_mask * protein2d
    dgram = dgram * template_pseudo_beta_mask[..., None]
    # Final feature: masked distogram plus the mask itself as a channel.
    templ_feat = torch.cat([dgram, template_pseudo_beta_mask[..., None]], dim=-1)
    tensors["templ_feat"] = templ_feat.float()
    return tensors
|
| 969 |
+
|
| 970 |
+
def transform(self, raw_feats):
    """Convert merged numpy features into the final model-ready tensor dict.

    Pipeline: numpy -> torch, target/MSA features, docking token bonds,
    (training-only) padding to fixed crop sizes, pair masks, template
    features, and a final type correction that folds short polymers into the
    protein class.  Returns the tensor dict.
    """
    # np to tensor
    tensors = dict()
    for key in raw_feats.keys():
        tensors[key] = torch.from_numpy(raw_feats[key])
    # Make Target & MSA Feat
    tensors = self.make_feats(tensors)

    # Make Token Bond Feat (inter-chain docking bonds)
    tensors = self._make_token_bonds(tensors)

    # Padding: only during training; inference keeps native sizes.
    if not self.inference_mode:
        tensors = self._pad_to_size(tensors)

    # Pairwise masks derived from the 1-D token/atom masks.
    tensors["z_mask"] = tensors["s_mask"][None] * tensors["s_mask"][:, None]
    tensors["ap_mask"] = tensors["a_mask"][None] * tensors["a_mask"][:, None]
    # Nucleic-acid classes are not produced by this loader; keep zeros for
    # interface compatibility.
    tensors["is_dna"] = torch.zeros_like(tensors["is_protein"])
    tensors["is_rna"] = torch.zeros_like(tensors["is_protein"])

    # Template
    tensors = self.get_template_feat(tensors)

    # Correct Type: short polymers count as protein, not ligand.
    is_short_poly = tensors.pop("is_short_poly")
    tensors["is_protein"] = tensors["is_protein"] + is_short_poly
    tensors["is_ligand"] = tensors["is_ligand"] - is_short_poly
    return tensors
|
| 999 |
+
|
| 1000 |
+
# residue_index 0-100
|
| 1001 |
+
# CCDS
|
| 1002 |
+
# CCD<RES_ID>
|
| 1003 |
+
#
|
| 1004 |
+
def load(
        self,
        system_pkl_path,  # Receptor chains: all_atom_positions pocket_res_feat Ligand_chains
        template_receptor_chain_ids=None,  # ["A"]
        template_ligand_chain_ids=None,  # ["1"]
        remove_receptor=False,
        remove_ligand=False,  # True, CCD_META_DATA ref_mol
        smi=None,  # "CCCCC"
):
    """Load one system pickle and build model-ready tensors for it.

    Args:
        system_pkl_path: path to a ``<system_id>.pkl.gz`` file mapping chain
            ids to per-chain label dicts (``ccds``, ``all_atom_positions``,
            ``all_atom_mask``, ...). Ligand chain ids are all-digit strings;
            receptor chain ids are not.
        template_receptor_chain_ids: receptor chains to use; defaults to all
            non-digit chain ids in the pickle.
        template_ligand_chain_ids: ligand chains to use; defaults to all
            digit chain ids in the pickle.
        remove_receptor: drop receptor chains (requires ``smi`` + inference
            mode when combined with ``remove_ligand``).
        remove_ligand: replace the pickled ligand(s) with ``smi`` (if given).
        smi: SMILES string for a replacement ligand; only valid together
            with ``remove_ligand``.

    Returns:
        ``(tensors, infer_meta_data)`` — the transformed feature tensors and
        the bookkeeping dict used for cropping / PDB writing.
    """
    ##########################################################
    #            Initialization of Configs                   #
    ##########################################################
    if self.inference_mode:
        pocket_type = self.infer_pocket_type
        pocket_cutoff = self.infer_pocket_cutoff
        pocket_dist_type = self.infer_pocket_dist_type
        use_pocket = self.infer_use_pocket
        use_key_res = self.infer_use_key_res
    else:
        # BUGFIX: random.choices() returns a *list* of k selections, so the
        # original values never compared equal to "atom"/"ligand"/... and the
        # distance-type dispatch below always raised NotImplementedError in
        # training. Take element [0] to get the sampled string.
        pocket_type = random.choices(
            ["atom", "ca"],
            [self.train_pocket_type_atom_ratio, 1 - self.train_pocket_type_atom_ratio])[0]

        # BUGFIX: "ligand_cetre" was a typo; the comparison below expects
        # "ligand_centre".
        pocket_dist_type = random.choices(
            ["ligand", "ligand_centre"],
            [self.train_pocket_dist_type_ligand_ratio, 1 - self.train_pocket_dist_type_ligand_ratio])[0]

        # Sample the pocket cutoff uniformly from the configured range for
        # the chosen distance type.
        if pocket_dist_type == "ligand":
            pocket_cutoff = self.train_pocket_cutoff_ligand_min + random.random() * (
                    self.train_pocket_cutoff_ligand_max - self.train_pocket_cutoff_ligand_min)
        else:
            pocket_cutoff = self.train_pocket_cutoff_ligand_centre_min + random.random() * (
                    self.train_pocket_cutoff_ligand_centre_max - self.train_pocket_cutoff_ligand_centre_min)

        use_pocket = random.random() < self.train_use_pocket_ratio
        use_key_res = random.random() < self.train_use_key_res_ratio

    ##########################################################
    #            Initialization of features                  #
    ##########################################################
    # Strip the ".pkl.gz" suffix (7 chars) to recover the system id.
    system_id = os.path.split(system_pkl_path)[1][:-7]
    all_chain_labels = {}
    all_chain_features = {}

    CONF_META_DATA = {}
    ref_mol = None
    CHAIN_CLASS = {}
    SEQ3 = {}
    ASYM_ID = {}

    ##########################################################
    #            Load All Chain Labels                       #
    ##########################################################
    data = load_pkl(system_pkl_path)
    # Digit-only chain ids are ligands; everything else is receptor.
    if template_receptor_chain_ids is None:
        template_receptor_chain_ids = [chain_id for chain_id in data.keys() if not chain_id.isdigit()]

    if template_ligand_chain_ids is None:
        template_ligand_chain_ids = [chain_id for chain_id in data.keys() if chain_id.isdigit()]
    # TODO: Save Ligand Centre for cropped screening
    # Calculate Pocket Residue According to Template receptor and ligand
    if not remove_receptor and len(template_ligand_chain_ids) > 0:
        for receptor_chain_id in template_receptor_chain_ids:
            ccds_receptor = data[receptor_chain_id]["ccds"]
            x_gt_receptor = data[receptor_chain_id]["all_atom_positions"]
            x_exists_receptor = data[receptor_chain_id]["all_atom_mask"]
            # Collect receptor reference points (all resolved atoms, or just
            # CA) along with a parallel map back to residue indices.
            x_gt_this_receptor = []
            atom_id_to_ccd_id = []
            for ccd_id, (ccd, x_gt_ccd, x_exists_ccd) in enumerate(
                    zip(ccds_receptor, x_gt_receptor, x_exists_receptor)):

                if rc.is_standard(ccd):
                    x_exists_ccd_bool = x_exists_ccd.astype(np.bool_)
                    if x_exists_ccd_bool[1]:  # CA exists
                        if pocket_type == "atom":
                            num_atoms = sum(x_exists_ccd_bool)
                            x_gt_this_receptor.append(x_gt_ccd[x_exists_ccd_bool])
                            atom_id_to_ccd_id += num_atoms * [ccd_id]
                        else:
                            x_gt_this_receptor.append(x_gt_ccd[1][None])
                            atom_id_to_ccd_id.append(ccd_id)
            x_gt_this_receptor = np.concatenate(x_gt_this_receptor, axis=0)
            atom_id_to_ccd_id = np.array(atom_id_to_ccd_id)
            used_ccd_ids = []
            for ligand_chain_id in template_ligand_chain_ids:
                x_gt_ligand = data[ligand_chain_id]["all_atom_positions"]
                x_exists_ligand = data[ligand_chain_id]["all_atom_mask"]
                # Keep only resolved ligand atoms.
                x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)[
                    np.concatenate(x_exists_ligand, axis=0).astype(np.bool_)]
                if pocket_dist_type == "ligand":
                    # Residue is in the pocket if any reference point is
                    # within the cutoff of any ligand atom.
                    used_ccd_bool = np.any(
                        np.linalg.norm(x_gt_this_receptor[:, None] - x_gt_ligand[None], axis=-1) < pocket_cutoff,
                        axis=-1)
                elif pocket_dist_type == "ligand_centre":
                    # BUGFIX: the ligand centre is its mean coordinate; the
                    # original used np.min, which is a corner of the
                    # axis-aligned bounding box, not the centre.
                    x_centre = np.mean(x_gt_ligand, axis=0, keepdims=True)
                    used_ccd_bool = np.any(
                        np.linalg.norm(x_gt_this_receptor[:, None] - x_centre[None], axis=-1) < pocket_cutoff,
                        axis=-1)
                else:
                    raise NotImplementedError()
                used_ccd_ids.append(atom_id_to_ccd_id[used_ccd_bool])
            used_ccd_ids = list(sorted(list(set(np.concatenate(used_ccd_ids, axis=0).tolist()))))
            # One-hot pocket membership per residue.
            pocket_res_feat = np.zeros([len(ccds_receptor)], dtype=np.float32)
            pocket_res_feat[used_ccd_ids] = 1.
            all_chain_labels[receptor_chain_id] = data[receptor_chain_id]
            all_chain_labels[receptor_chain_id]["pocket_res_feat"] = pocket_res_feat

    if remove_ligand:
        if remove_receptor:
            assert smi is not None and self.inference_mode
        if smi is not None:
            all_chain_labels, CONF_META_DATA, ref_mol = self._update_smi(smi, all_chain_labels, CONF_META_DATA)
    else:
        assert smi is None
        for ligand_chain_id in template_ligand_chain_ids:
            all_chain_labels[ligand_chain_id] = data[ligand_chain_id]
        # For Benchmarking: expose the reference molecule when the system
        # has exactly one single-CCD ligand.
        if len(template_ligand_chain_ids) == 1:
            ccds = all_chain_labels[template_ligand_chain_ids[0]]["ccds"]
            if len(ccds) == 1:
                # ref_mol = self.ccd_id_ref_mol[ccds[0]]
                ref_mol = self.ccd_id_meta_data[ccds[0]]["ref_mol"]
    ##########################################################
    #            Init All Chain Features                     #
    ##########################################################
    for chain_id, chain_feature in all_chain_labels.items():
        ccds = chain_feature["ccds"]
        CONF_META_DATA = self._update_CONF_META_DATA(CONF_META_DATA, ccds)
        SEQ3[chain_id] = "-".join(ccds)
        chain_class = "protein" if not chain_id.isdigit() else "ligand"
        chain_feature["chain_class"] = chain_class
        all_chain_features[chain_id] = self._update_chain_feature(
            chain_feature,
            CONF_META_DATA,
            use_pocket,
            use_key_res,
        )
        CHAIN_CLASS[chain_id] = chain_class
    ##########################################################
    #            Add Assembly Feature                        #
    ##########################################################
    all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3, ASYM_ID)

    infer_meta_data = {
        "CONF_META_DATA": CONF_META_DATA,
        "SEQ3": SEQ3,
        "ASYM_ID": ASYM_ID,
        "CHAIN_CLASS": CHAIN_CLASS,
        "ref_mol": ref_mol,
        "system_id": system_id,
    }
    ##########################################################
    #            Cropping                                    #
    ##########################################################
    if self.crop_size is not None:
        all_chain_features, infer_meta_data = self._crop_all_chain_features(
            all_chain_features, infer_meta_data, crop_centre=None)  # TODO: Add Cropping Centre
    ##########################################################
    #            Pair & Merge                                #
    ##########################################################
    raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)

    ##########################################################
    #            Transform                                   #
    ##########################################################
    tensors = self.transform(raw_feats)
    return tensors, infer_meta_data
|
| 1174 |
+
|
| 1175 |
+
def write_pdb(self, x_pred, fname, infer_meta_data, receptor_only=False, ligand_only=False):
    """Write predicted coordinates to a single-model PDB file at `fname`.

    The original body was a byte-for-byte duplicate of `write_pdb_block`
    followed by a dump; delegate to that method so the fixed-column PDB
    formatting lives in exactly one place.

    Args:
        x_pred: per-atom predicted coordinates, indexed in the same order
            as `infer_meta_data["atom_id_to_conformer_atom_id"]`.
        fname: output path for the PDB text.
        infer_meta_data: bookkeeping dict produced by `load`.
        receptor_only / ligand_only: mutually exclusive filters selecting
            only ATOM or only HETATM records; setting both raises
            NotImplementedError (same contract as `write_pdb_block`).
    """
    out = self.write_pdb_block(
        x_pred,
        infer_meta_data,
        receptor_only=receptor_only,
        ligand_only=ligand_only,
    )
    dump_txt(out, fname)
|
| 1229 |
+
|
| 1230 |
+
def write_pdb_block(self, x_pred, infer_meta_data, receptor_only=False, ligand_only=False):
    """Format predicted atom coordinates as a single-model PDB text block.

    Args:
        x_pred: per-atom predicted coordinates; each entry must support
            ``.tolist()`` returning ``[x, y, z]`` (numpy row or tensor).
        infer_meta_data: bookkeeping dict from ``load``; this method reads
            ``ccds``, ``atom_id_to_conformer_atom_id``,
            ``conformer_id_to_chunk_sizes``, ``CHAIN_CLASS``,
            ``CONF_META_DATA``, ``residue_index`` and ``asym_id``.
        receptor_only: keep only ATOM (receptor) records.
        ligand_only: keep only HETATM (ligand) records.

    Returns:
        The PDB text (MODEL/TER/ENDMDL/END framing included).

    Raises:
        NotImplementedError: if both ``receptor_only`` and ``ligand_only``
            are set.
    """
    ccds = infer_meta_data["ccds"]
    atom_id_to_conformer_atom_id = infer_meta_data["atom_id_to_conformer_atom_id"]
    ccd_chunk_sizes = infer_meta_data["conformer_id_to_chunk_sizes"].tolist()
    CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]
    conf_meta_data = infer_meta_data["CONF_META_DATA"]
    residue_index = infer_meta_data["residue_index"].tolist()
    asym_id = infer_meta_data["asym_id"].tolist()

    atom_lines = []
    atom_offset = 0
    # One residue/conformer per ccd_id; chunk_size atoms belong to it.
    for ccd_id, (ccd, chunk_size, res_id) in enumerate(zip(ccds, ccd_chunk_sizes, residue_index)):
        # Map this residue's flat atom ids back into its conformer's atom
        # table to recover names and elements.
        inner_atom_idx = atom_id_to_conformer_atom_id[atom_offset:atom_offset + chunk_size]
        atom_names = [conf_meta_data[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
        atom_elements = [PeriodicTable[conf_meta_data[ccd]["ref_element"][i]] for i in inner_atom_idx]
        chain_tag = PDB_CHAIN_IDS[int(asym_id[ccd_id])]
        # NOTE(review): CHAIN_CLASS is indexed by ccd_id here — presumably
        # it has been expanded to a per-token array by pair_and_merge, not
        # the per-chain dict built in load; verify against the caller.
        record_type = "HETATM" if CHAIN_CLASS[ccd_id] == "ligand" else "ATOM"

        for ccd_atom_idx, atom_name in enumerate(atom_names):
            x = x_pred[atom_offset]
            # PDB convention: 4-char names fill the field, shorter names
            # are right-shifted by one leading space.
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            res_name_3 = ccd
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_elements[ccd_atom_idx]
            # b_factor = torch.argmax(plddt[atom_offset],dim=-1).item()*2 +1
            b_factor = 70.
            charge = 0
            pos = x.tolist()
            # Fixed-column ATOM/HETATM record (AlphaFold-style writer).
            # NOTE(review): the PDB format requires 3 spaces after the
            # insertion-code field and 10 after the B-factor; confirm these
            # space runs were not collapsed when this file was copied.
            atom_line = (
                f"{record_type:<6}{atom_offset + 1:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3.split()[0][-3:]:>3} {chain_tag:>1}"
                f"{res_id + 1:>4}{insertion_code:>1} "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f} "
                f"{element:>2}{charge:>2}"
            )
            # Filter records according to the receptor/ligand flags.
            if receptor_only and not ligand_only:
                if record_type == "ATOM":
                    atom_lines.append(atom_line)
            elif not receptor_only and ligand_only:
                if record_type == "HETATM":
                    atom_lines.append(atom_line)
            elif not receptor_only and not ligand_only:
                atom_lines.append(atom_line)
            else:
                raise NotImplementedError()
            atom_offset += 1
            # Stop once every predicted atom has been written (padding may
            # make the ccd tables longer than the atom table).
            if atom_offset == len(atom_id_to_conformer_atom_id):
                break
    out = "\n".join(atom_lines)
    out = f"MODEL 1\n{out}\nTER\nENDMDL\nEND"
    return out
|
PhysDock/data/feature_loader_plinder.py
ADDED
|
@@ -0,0 +1,1258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
############################
|
| 2 |
+
# 0.85 Receptor+Ligand
|
| 3 |
+
# 0.5 APO Template
|
| 4 |
+
# 0.5 HOLO Template
|
| 5 |
+
# 0.05 Protein (All APO or PRED)
|
| 6 |
+
# 0.1 Ligand
|
| 7 |
+
############################
|
| 8 |
+
import copy
|
| 9 |
+
import os
|
| 10 |
+
import random
|
| 11 |
+
from functools import reduce
|
| 12 |
+
from operator import add
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
# Key Res
|
| 17 |
+
# Dynamic Cutoff
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from stdock.data.constants.PDBData import protein_letters_3to1_extended
|
| 22 |
+
from stdock.data.constants import restype_constants as rc
|
| 23 |
+
from stdock.utils.io_utils import convert_md5_string, load_json, load_pkl, dump_txt
|
| 24 |
+
from stdock.data.tools.feature_processing_multimer import pair_and_merge
|
| 25 |
+
from stdock.utils.tensor_utils import centre_random_augmentation_np_apply, dgram_from_positions, \
|
| 26 |
+
centre_random_augmentation_np_batch
|
| 27 |
+
from stdock.data.constants.periodic_table import PeriodicTable
|
| 28 |
+
|
| 29 |
+
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class FeatureLoader:
|
| 33 |
+
def __init__(
        self,
        # config,
        token_crop_size=256,
        atom_crop_size=256 * 8,
        inference_mode=False,
):
    """Set up crop sizes, CCD metadata, and the PLINDER train/test
    sampling tables (ids, weights, normalized probabilities).

    NOTE(review): data locations are hard-coded absolute paths — consider
    moving them into a config.
    """
    self.inference_mode = inference_mode
    self.msa_features_path = "/2022133002/data/stfold-data-v5/features/msa_features/"
    self.uniprot_msa_features_path = "/2022133002/data/stfold-data-v5/features/uniprot_msa_features/"

    self.token_crop_size = token_crop_size
    self.atom_crop_size = atom_crop_size
    self.token_bond_threshold = 2.4

    # Chemical-component metadata (reference conformers, bond tables, ...).
    self.ccd_id_meta_data = load_pkl(
        "/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_meta_data_confs_chars.pkl.gz")

    self.samples_path = "/2022133002/data/plinder/2024-06/v2/plinder_samples_raw_data_v2"
    to_remove = set(load_json("/2022133002/projects/stdock/stdock_v9.6/scripts/to_remove.json"))
    weights = load_json(
        "/2022133002/projects/stdock/stdock_v9.5/scripts/cluster_scripts/train_samples_new_weight_seq.json")
    self.splits = load_json("/2022133002/data/plinder/2024-06/v2/splits/splits.json")

    def select_ids(want_test):
        # Weighted ids that are in the split table, not black-listed, and
        # on the requested side of the train/test split.
        return [sid for sid in weights
                if sid in self.splits
                and (self.splits[sid] == "test") == want_test
                and sid not in to_remove]

    self.used_sample_ids = select_ids(False)
    print("train samples", len(self.used_sample_ids))
    self.weights = np.array([weights[sid] for sid in self.used_sample_ids])
    # Normalized sampling distribution over training systems.
    self.probabilities = torch.from_numpy(self.weights / self.weights.sum())

    self.used_test_sample_ids = select_ids(True)
    print("test samples", len(self.used_test_sample_ids))
    self.test_weights = np.array([weights[sid] for sid in self.used_test_sample_ids])
    self.test_probabilities = torch.from_numpy(self.test_weights / self.test_weights.sum())
|
| 73 |
+
|
| 74 |
+
def _update_CONF_META_DATA(self, CONF_META_DATA, ccds):
    """Ensure every CCD code in `ccds` has an entry in the conformer
    metadata cache, building it from `self.ccd_id_meta_data` on demand.

    Each entry carries a per-atom feature matrix ("ref_feat"), a pairwise
    token-bond feature tensor ("rel_tok_feat"), raw atom names/elements/
    bond tables, and — for non-standard components — a batch of 32
    randomly augmented conformers ("batch_ref_pos").

    Returns the (mutated) CONF_META_DATA dict.
    """
    for ccd in ccds:
        if ccd in CONF_META_DATA:
            # Already built for a previous chain/residue; entries are
            # shared across the whole system.
            continue
        ccd_features = self.ccd_id_meta_data[ccd]
        ref_pos = ccd_features["ref_pos"]

        # Centre the reference conformer at the origin.
        ref_pos = ref_pos - np.mean(ref_pos, axis=0, keepdims=True)

        CONF_META_DATA[ccd] = {
            # Per-atom features, concatenated along the last axis:
            # centred position, formal charge, one-hot element (128),
            # aromaticity, one-hot degree (9), hybridization (7),
            # implicit valence (9), chirality (3), and ring-size flags
            # for rings of size 3..8. Order is load-bearing — it must
            # match the model's input layout.
            "ref_feat": np.concatenate([
                ref_pos,
                ccd_features["ref_charge"][..., None],
                rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
                ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
                rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
                rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
                rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
                rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
                ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
            ], axis=-1),
            # Pairwise (atom x atom) bond features: one-hot graph
            # distance (32), one-hot bond type (5), plus scalar flags.
            "rel_tok_feat": np.concatenate([
                rc.eye_32[ccd_features["d_token"]].astype(np.float32),
                rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
                ccd_features["token_bonds"].astype(np.float32)[..., None],
                ccd_features["bond_as_double"].astype(np.float32)[..., None],
                ccd_features["bond_in_ring"].astype(np.float32)[..., None],
                ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
                ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
            ], axis=-1),
            "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
            "ref_element": ccd_features["ref_element"],
            "token_bonds": ccd_features["token_bonds"],

        }
        if not rc.is_standard(ccd):
            # Non-standard components (ligands) additionally get a batch
            # of 32 conformers for data augmentation.
            conformers = ccd_features["conformers"]
            if conformers is None:
                # No sampled conformers available: replicate the (uncentred)
                # reference pose 32 times.
                conformers = np.repeat(ccd_features["ref_pos"][None], 32, axis=0)
            else:
                # Draw 32 conformers with replacement.
                conformers = np.stack([random.choice(ccd_features["conformers"]) for i in range(32)], axis=0)
            CONF_META_DATA[ccd]["batch_ref_pos"] = centre_random_augmentation_np_batch(conformers)

    return CONF_META_DATA
|
| 123 |
+
|
| 124 |
+
def _update_CONF_META_DATA_ligand(self, CONF_META_DATA, sequence_3, ccd_features):
    """Insert conformer metadata for ligand CCD codes using an explicitly
    supplied feature dict instead of the global `self.ccd_id_meta_data`.

    Unlike `_update_CONF_META_DATA`, this variant (a) applies the same
    `ccd_features` to every code in `sequence_3`, (b) overwrites existing
    entries rather than skipping them, and (c) does NOT centre "ref_pos"
    before concatenating it into "ref_feat".
    NOTE(review): presumably intended for single-component SMILES ligands
    where all three differences are harmless — verify against callers.

    Args:
        CONF_META_DATA: cache dict to mutate and return.
        sequence_3: "-"-joined CCD codes (e.g. "LIG").
        ccd_features: feature dict in the same schema as the entries of
            `self.ccd_id_meta_data`.
    """
    ccds = sequence_3.split("-")
    # ccd_features = self.ccd_id_meta_data[ccd]
    for ccd in ccds:
        CONF_META_DATA[ccd] = {
            # Per-atom features; same layout as _update_CONF_META_DATA
            # (position, charge, element, aromaticity, degree,
            # hybridization, valence, chirality, ring flags 3..8).
            "ref_feat": np.concatenate([
                ccd_features["ref_pos"],
                ccd_features["ref_charge"][..., None],
                rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
                ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
                rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
                rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
                rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
                rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
                ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
                ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
            ], axis=-1),
            # Pairwise bond features; same layout as _update_CONF_META_DATA.
            "rel_tok_feat": np.concatenate([
                rc.eye_32[ccd_features["d_token"]].astype(np.float32),
                rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
                ccd_features["token_bonds"].astype(np.float32)[..., None],
                ccd_features["bond_as_double"].astype(np.float32)[..., None],
                ccd_features["bond_in_ring"].astype(np.float32)[..., None],
                ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
                ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
            ], axis=-1),
            "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
            "ref_element": ccd_features["ref_element"],
            "token_bonds": ccd_features["token_bonds"],

        }
        if not rc.is_standard(ccd):
            # Ligands get a 32-conformer augmentation batch.
            conformers = ccd_features["conformers"]
            if conformers is None:
                conformers = np.repeat(ccd_features["ref_pos"][None], 32, axis=0)
            else:
                conformers = np.stack([random.choice(ccd_features["conformers"]) for i in range(32)], axis=0)
            CONF_META_DATA[ccd]["batch_ref_pos"] = centre_random_augmentation_np_batch(conformers)

    return CONF_META_DATA
|
| 168 |
+
|
| 169 |
+
def _update_chain_feature(self, chain_feature, CONF_META_DATA):
    """Turn one raw chain-label dict into model-ready per-chain features.

    For protein chains, precomputed MSA features are loaded from disk,
    keyed by an MD5 of the one-letter sequence; other chains get a trivial
    two-row MSA.  Per-residue atom data is then flattened into contiguous
    atom arrays, dropping residues whose ground-truth coordinates are
    missing (or, for standard residues, more than one atom short).

    Note: `chain_feature` is consumed (mutated and del'ed); a new dict is
    returned.
    """

    ccds_ori = chain_feature["ccds"]
    chain_class = chain_feature["chain_class"]
    if chain_class == "protein":
        # MSA features are cached on disk keyed by MD5 of "protein:<seq>".
        sequence = "".join([protein_letters_3to1_extended.get(ccd, "X") for ccd in ccds_ori])
        md5 = convert_md5_string(f"protein:{sequence}")

        chain_feature.update(
            load_pkl(os.path.join(self.msa_features_path, f"{md5}.pkl.gz"))
        )
        chain_feature.update(
            load_pkl(os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"))
        )
    else:
        # Non-protein chains: a 2-row dummy MSA of restype indices (20 = UNK)
        # and an all-zero deletion matrix.
        chain_feature["msa"] = np.array([[rc.standard_ccds.index(ccd)
                                          if ccd in rc.standard_ccds else 20 for ccd in ccds_ori]] * 2,
                                        dtype=np.int8)
        chain_feature["deletion_matrix"] = np.zeros_like(chain_feature["msa"])

    # Merge Key Res Feat & Augmentation
    if "salt bridges" in chain_feature:
        # One channel per interaction type; the 7th channel is reserved (zeros).
        key_res_feat = np.stack([
            chain_feature["salt bridges"],
            chain_feature["pi-cation interactions"],
            chain_feature["hydrophobic interactions"],
            chain_feature["pi-stacking"],
            chain_feature["hydrogen bonds"],
            chain_feature["metal complexes"],
            np.zeros_like(chain_feature["salt bridges"]),
        ], axis=-1).astype(np.float32)
    else:
        key_res_feat = np.zeros(
            [len(ccds_ori), 7], dtype=np.float32
        )
    # is_key_res is computed BEFORE the training-time dropout below, so the
    # flag stays on even when all channels happen to be dropped.
    is_key_res = np.any(key_res_feat.astype(np.bool_), axis=-1).astype(np.float32)
    # Augmentation
    if not self.inference_mode:
        # Randomly drop ~half of the key-residue channels during training.
        key_res_feat = key_res_feat * (np.random.random([len(ccds_ori), 7]) < 0.5)
    # else:
    #     # TODO: No key res in inference mode
    #     key_res_feat = key_res_feat * 0
    # Atom
    x_gt = []
    atom_id_to_conformer_atom_id = []

    # Conformer
    conformer_id_to_chunk_sizes = []
    residue_index = []
    restype = []
    ccds = []

    conformer_exists = []

    for c_id, ccd in enumerate(chain_feature["ccds"]):
        no_atom_this_conf = len(CONF_META_DATA[ccd]["ref_feat"])
        conformer_atom_ids_this_conf = np.arange(no_atom_this_conf)
        x_gt_this_conf = chain_feature["all_atom_positions"][c_id]
        x_exists_this_conf = chain_feature["all_atom_mask"][c_id].astype(np.bool_)

        # TODO DEBUG
        #
        # Keep a residue if at least one atom is resolved ...
        conformer_exist = np.any(x_exists_this_conf).item()
        if rc.is_standard(ccd):
            # ... but standard residues must have at most ONE atom missing.
            conformer_exist = np.sum(x_exists_this_conf).item() > len(x_exists_this_conf) - 2
            # conformer_exist = conformer_exist and x_exists_this_conf[1]
            # if ccd != "GLY":
            #     conformer_exist = conformer_exist and x_exists_this_conf[4]

        conformer_exists.append(conformer_exist)
        if conformer_exist:
            # Atomwise
            x_gt.append(x_gt_this_conf[x_exists_this_conf])
            atom_id_to_conformer_atom_id.append(conformer_atom_ids_this_conf[x_exists_this_conf])
            # Tokenwise
            residue_index.append(c_id)
            conformer_id_to_chunk_sizes.append(np.sum(x_exists_this_conf).item())
            restype.append(rc.standard_ccds.index(ccd) if ccd in rc.standard_ccds else 20)
            ccds.append(ccd)
    # NOTE(review): raises ValueError if no residue survives the filter
    # above (empty concatenate) — presumably callers guarantee at least one
    # resolved residue per chain; confirm upstream.
    x_gt = np.concatenate(x_gt, axis=0)
    atom_id_to_conformer_atom_id = np.concatenate(atom_id_to_conformer_atom_id, axis=0, dtype=np.int32)
    residue_index = np.array(residue_index, dtype=np.int64)
    conformer_id_to_chunk_sizes = np.array(conformer_id_to_chunk_sizes, dtype=np.int64)
    restype = np.array(restype, dtype=np.int64)

    conformer_exists = np.array(conformer_exists, dtype=np.bool_)

    chain_feature_update = {
        "x_gt": x_gt,
        "atom_id_to_conformer_atom_id": atom_id_to_conformer_atom_id,
        "residue_index": residue_index,
        "conformer_id_to_chunk_sizes": conformer_id_to_chunk_sizes,
        "restype": restype,
        "ccds": ccds,
        # MSA columns are filtered to the surviving residues only.
        "msa": chain_feature["msa"][:, conformer_exists],
        "deletion_matrix": chain_feature["deletion_matrix"][:, conformer_exists],
        "chain_class": chain_class,
        "key_res_feat": key_res_feat[conformer_exists],
        "is_key_res": is_key_res[conformer_exists],
    }

    chain_feature_update["is_protein"] = np.array([chain_class == "protein"] * len(ccds)).astype(np.float32)
    chain_feature_update["is_ligand"] = np.array([chain_class != "protein"] * len(ccds)).astype(np.float32)
    # Assert Short Poly Chain like peptide: non-protein chain of >= 2
    # residues whose components are standard CCDs (e.g. a short peptide
    # treated as ligand).
    chain_feature_update["is_short_poly"] = np.array(
        [chain_class != "protein" and len(ccds) >= 2 and rc.is_standard(ccd) for ccd in ccds]
    ).astype(np.float32)

    if "msa_all_seq" in chain_feature:
        # Paired ("all_seq") MSA features get the same column filtering.
        chain_feature_update["msa_all_seq"] = chain_feature["msa_all_seq"][:, conformer_exists]
        chain_feature_update["deletion_matrix_all_seq"] = \
            chain_feature["deletion_matrix_all_seq"][:, conformer_exists]
        chain_feature_update["msa_species_identifiers_all_seq"] = chain_feature["msa_species_identifiers_all_seq"]
    # Release the (potentially large) raw dict before returning.
    del chain_feature
    return chain_feature_update
|
| 285 |
+
|
| 286 |
+
def _add_assembly_feature(self, all_chain_features, SEQ3):
|
| 287 |
+
entities = {}
|
| 288 |
+
for chain_id, seq3 in SEQ3.items():
|
| 289 |
+
if seq3 not in entities:
|
| 290 |
+
entities[seq3] = [chain_id]
|
| 291 |
+
else:
|
| 292 |
+
entities[seq3].append(chain_id)
|
| 293 |
+
|
| 294 |
+
asym_id = 0
|
| 295 |
+
ASYM_ID = {}
|
| 296 |
+
for entity_id, seq3 in enumerate(list(entities.keys())):
|
| 297 |
+
chain_ids = copy.deepcopy(entities[seq3])
|
| 298 |
+
if not self.inference_mode:
|
| 299 |
+
# sym_id augmentation
|
| 300 |
+
random.shuffle(chain_ids)
|
| 301 |
+
for sym_id, chain_id in enumerate(chain_ids):
|
| 302 |
+
num_conformers = len(all_chain_features[chain_id]["ccds"])
|
| 303 |
+
all_chain_features[chain_id]["asym_id"] = \
|
| 304 |
+
np.full([num_conformers], fill_value=asym_id, dtype=np.int32)
|
| 305 |
+
all_chain_features[chain_id]["sym_id"] = \
|
| 306 |
+
np.full([num_conformers], fill_value=sym_id, dtype=np.int32)
|
| 307 |
+
all_chain_features[chain_id]["entity_id"] = \
|
| 308 |
+
np.full([num_conformers], fill_value=entity_id, dtype=np.int32)
|
| 309 |
+
|
| 310 |
+
all_chain_features[chain_id]["sequence_3"] = seq3
|
| 311 |
+
ASYM_ID[asym_id] = chain_id
|
| 312 |
+
|
| 313 |
+
asym_id += 1
|
| 314 |
+
return all_chain_features, ASYM_ID
|
| 315 |
+
|
| 316 |
+
def load_all_chain_features(self, all_chain_labels):
    """Build per-chain model features from raw chain labels.

    For each chain this refreshes the shared conformer metadata, records
    its CCD sequence string and chain class (numeric chain ids are
    ligands, everything else protein), and converts the raw labels via
    `_update_chain_feature`.  Assembly ids are then stamped on.

    Returns (all_chain_features, infer_meta_data) where the meta dict
    bundles CONF_META_DATA, SEQ3, ASYM_ID and CHAIN_CLASS for the
    downstream crop / merge steps.
    """
    all_chain_features = {}
    CONF_META_DATA = {}
    SEQ3 = {}
    CHAIN_CLASS = {}

    for chain_id, chain_feature in all_chain_labels.items():
        ccds = chain_feature["ccds"]
        CONF_META_DATA = self._update_CONF_META_DATA(CONF_META_DATA, ccds)
        SEQ3[chain_id] = "-".join(ccds)

        # Purely numeric chain ids denote ligand chains.
        chain_class = "ligand" if chain_id.isdigit() else "protein"
        chain_feature["chain_class"] = chain_class
        CHAIN_CLASS[chain_id] = chain_class

        all_chain_features[chain_id] = self._update_chain_feature(
            chain_feature, CONF_META_DATA)

    all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3)

    infer_meta_data = {
        "CONF_META_DATA": CONF_META_DATA,
        "SEQ3": SEQ3,
        "ASYM_ID": ASYM_ID,
        "CHAIN_CLASS": CHAIN_CLASS,
    }
    return all_chain_features, infer_meta_data
|
| 345 |
+
|
| 346 |
+
def _spatial_crop_v2(self, all_chain_features, infer_meta_data):
    """Spatially crop the assembly to the atom/token budget.

    A crop centre is sampled by one of three schemes (ligand atom,
    protein-ligand/inter-chain interface CA, or any CA), then whole
    residues/conformers are greedily added in order of distance until
    `self.atom_crop_size` or `self.token_crop_size` would be exceeded.
    All per-chain feature arrays are masked in place to the kept set and
    chains with nothing kept are removed.

    Fix vs. original: the same-chain mask for interface selection was
    `asym_id_ca[None] * asym_id_ca[:, None]` — a *product*, which is 0
    for every pair involving asym_id 0, so the first chain's intra-chain
    pairs were never pushed past the 15 A interface threshold.  An
    equality test is what the +100 distance offset clearly intends.
    """
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ordered_chain_ids = list(all_chain_features.keys())

    x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)

    # Tokenwise bookkeeping built in global atom order.
    token_id_to_centre_atom_id = []
    token_id_to_conformer_id = []
    token_id_to_ccd_chunk_sizes = []
    token_id_to_ccd = []
    asym_id_ca = []
    token_id = 0
    atom_id = 0
    conf_id = 0
    x_gt_ligand = []  # atoms of single-residue ligand chains (crop-centre candidates)
    for chain_id in ordered_chain_ids:
        if chain_id.isdigit() and len(all_chain_features[chain_id]["ccds"]) == 1:
            x_gt_ligand.append(all_chain_features[chain_id]["x_gt"])
        atom_offset = 0
        for ccd, chunk_size_this_ccd, asym_id in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
                all_chain_features[chain_id]["asym_id"],
        ):
            inner_atom_idx = all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][
                atom_offset:atom_offset + chunk_size_this_ccd]
            atom_names = [CONF_META_DATA[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
            if rc.is_standard(ccd):
                # Standard residue: one token, centred on its centre atom (CA).
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                        token_id_to_centre_atom_id.append(atom_id)
                        token_id_to_conformer_id.append(conf_id)
                        token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                        token_id_to_ccd.append(ccd)
                        asym_id_ca.append(asym_id)
                    atom_id += 1
                token_id += 1
            else:
                # Ligand / non-standard residue: every atom is its own token.
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    token_id_to_centre_atom_id.append(atom_id)
                    token_id_to_conformer_id.append(conf_id)
                    token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                    token_id_to_ccd.append(ccd)
                    asym_id_ca.append(asym_id)
                    atom_id += 1
                    token_id += 1
            atom_offset += chunk_size_this_ccd
            conf_id += 1

    x_gt_ca = x_gt[token_id_to_centre_atom_id]
    asym_id_ca = np.array(asym_id_ca)

    crop_scheme_seed = random.random()
    # Scheme 1: centre on a random ligand atom (always, in inference mode).
    if crop_scheme_seed < (0.6 if not self.inference_mode else 1.0) and len(x_gt_ligand) > 0:
        x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
        x_gt_sel = random.choice(x_gt_ligand)[None]
    # Scheme 2: centre on an inter-chain interface token centre.
    elif crop_scheme_seed < 0.8 and len(set(asym_id_ca.tolist())) > 1:
        # BUG FIX: same-chain mask must be an equality test, not a product.
        chain_same = (asym_id_ca[None] == asym_id_ca[:, None]).astype(np.float32)
        dist = np.linalg.norm(x_gt_ca[:, None] - x_gt_ca[None], axis=-1)
        # Push intra-chain pairs past the interface threshold.
        dist = dist + chain_same * 100
        # interface_threshold
        mask = np.any(dist < 15, axis=-1)
        if mask.any():
            x_gt_sel = random.choice(x_gt_ca[mask])[None]
        else:
            x_gt_sel = random.choice(x_gt_ca)[None]
    # Scheme 3: centre on any token centre.
    else:
        x_gt_sel = random.choice(x_gt_ca)[None]
    dist = np.linalg.norm(x_gt_ca - x_gt_sel, axis=-1)
    token_idxs = np.argsort(dist)

    # Greedy selection of whole conformers, nearest first, within budget.
    select_ccds_idx = set()
    to_sum_atom = 0
    to_sum_token = 0
    for token_idx in token_idxs:
        ccd_idx = token_id_to_conformer_id[token_idx]
        ccd_chunk_size = token_id_to_ccd_chunk_sizes[token_idx]
        ccd_this_token = token_id_to_ccd[token_idx]
        if ccd_idx in select_ccds_idx:
            continue
        if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
            break
        # Standard residues cost one token; others one token per atom.
        to_add_token = 1 if rc.is_standard(ccd_this_token) else ccd_chunk_size
        if to_sum_token + to_add_token > self.token_crop_size:
            break
        select_ccds_idx.add(ccd_idx)
        to_sum_atom += ccd_chunk_size
        to_sum_token += to_add_token

    # Mask every per-chain feature down to the selected conformers/atoms.
    ccd_all_id = 0
    crop_chains = []
    for chain_id in ordered_chain_ids:
        conformer_used_mask = []
        atom_used_mask = []
        ccds = []
        for ccd, chunk_size_this_ccd in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
        ):
            keep = ccd_all_id in select_ccds_idx
            if keep:
                ccds.append(ccd)
                if chain_id not in crop_chains:
                    crop_chains.append(chain_id)
            conformer_used_mask.append(keep)
            atom_used_mask += [keep] * chunk_size_this_ccd
            ccd_all_id += 1
        conf_mask = np.array(conformer_used_mask).astype(np.bool_)
        atom_mask = np.array(atom_used_mask).astype(np.bool_)
        # Update All Chain Features
        all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
        all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
            all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
        all_chain_features[chain_id]["restype"] = all_chain_features[chain_id]["restype"][conf_mask]
        all_chain_features[chain_id]["residue_index"] = all_chain_features[chain_id]["residue_index"][conf_mask]
        all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = \
            all_chain_features[chain_id]["conformer_id_to_chunk_sizes"][conf_mask]
        # BUG Fix (kept from original): tokenwise side features must be
        # cropped together with the conformers.
        all_chain_features[chain_id]["key_res_feat"] = all_chain_features[chain_id]["key_res_feat"][conf_mask]
        all_chain_features[chain_id]["is_key_res"] = all_chain_features[chain_id]["is_key_res"][conf_mask]
        all_chain_features[chain_id]["is_protein"] = all_chain_features[chain_id]["is_protein"][conf_mask]
        all_chain_features[chain_id]["is_short_poly"] = all_chain_features[chain_id]["is_short_poly"][conf_mask]
        all_chain_features[chain_id]["is_ligand"] = all_chain_features[chain_id]["is_ligand"][conf_mask]
        all_chain_features[chain_id]["asym_id"] = all_chain_features[chain_id]["asym_id"][conf_mask]
        all_chain_features[chain_id]["sym_id"] = all_chain_features[chain_id]["sym_id"][conf_mask]
        all_chain_features[chain_id]["entity_id"] = all_chain_features[chain_id]["entity_id"][conf_mask]

        all_chain_features[chain_id]["ccds"] = ccds
        if "msa" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix"] = \
                all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
        if "msa_all_seq" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
                all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
    # Remove Unused Chains
    for chain_id in list(all_chain_features.keys()):
        if chain_id not in crop_chains:
            all_chain_features.pop(chain_id, None)

    return all_chain_features
|
| 495 |
+
|
| 496 |
+
def _spatial_crop(self, all_chain_features):
    """Legacy atom-centred spatial crop (superseded by `_spatial_crop_v2`).

    Picks a crop-centre atom (random, or near an inter-chain contact),
    then greedily keeps whole conformers nearest the centre within
    `self.atom_crop_size` / `self.token_crop_size`, masking all per-chain
    feature arrays accordingly and dropping empty chains.

    Fixes vs. original:
      * The same-chain mask was `asym_id[None] * asym_id[:, None]` — a
        product, which is 0 for every pair involving asym_id 0, so the
        first chain was never masked out of the contact search.  Replaced
        with the intended equality test.
      * The interface branch computed `x_gt[mask]` but then sampled from
        the *unmasked* `x_gt`, and left `x_gt_sel` undefined (NameError)
        when no contact existed.  Now samples from the masked atoms with
        an explicit fallback, mirroring `_spatial_crop_v2`.
    """

    ordered_chain_ids = list(all_chain_features.keys())
    # Per-atom lookups, in global atom order.
    atom_id_to_ccd_id = []
    atom_id_to_ccd_chunk_sizes = []
    atom_id_to_ccd = []

    ccd_all_id = 0
    for chain_id in ordered_chain_ids:
        for ccd, chunk_size_this_ccd in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
        ):
            atom_id_to_ccd_id += [ccd_all_id] * chunk_size_this_ccd
            atom_id_to_ccd_chunk_sizes += [chunk_size_this_ccd] * chunk_size_this_ccd
            atom_id_to_ccd += [ccd] * chunk_size_this_ccd
            ccd_all_id += 1

    to_sum_atom = 0
    to_sum_token = 0
    x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)

    spatial_crop_ratio = 0.3 if not self.inference_mode else 0
    if random.random() < spatial_crop_ratio or len(ordered_chain_ids) == 1:
        # Plain spatial crop around a random atom.
        x_gt_sel = random.choice(x_gt)[None]
    else:
        # Interface crop: prefer atoms within 4 A of another chain.
        asym_id = np.array(reduce(add, [
            [asym_id] * len(all_chain_features[chain_id]["x_gt"])
            for asym_id, chain_id in enumerate(ordered_chain_ids)
        ]))
        # BUG FIX: equality, not product, marks same-chain pairs.
        chain_same = (asym_id[None] == asym_id[:, None]).astype(np.float32)
        dist = np.linalg.norm(x_gt[:, None] - x_gt[None], axis=-1)
        # Push intra-chain pairs past the contact threshold.
        dist = dist + chain_same * 100
        mask = np.any(dist < 4, axis=-1)
        if mask.any():
            # BUG FIX: sample from the contact atoms, not all atoms.
            x_gt_sel = random.choice(x_gt[mask])[None]
        else:
            # BUG FIX: fallback so x_gt_sel is always defined.
            x_gt_sel = random.choice(x_gt)[None]
    dist = np.linalg.norm(x_gt - x_gt_sel, axis=-1)
    atom_idxs = np.argsort(dist)
    # Greedy selection of whole conformers, nearest atom first.
    select_ccds_idx = set()
    for atom_idx in atom_idxs:
        ccd_idx = atom_id_to_ccd_id[atom_idx]
        ccd_chunk_size = atom_id_to_ccd_chunk_sizes[atom_idx]
        ccd_this_atom = atom_id_to_ccd[atom_idx]
        if ccd_idx in select_ccds_idx:
            continue
        if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
            break
        # Standard residues cost one token; others one token per atom.
        to_add_token = 1 if rc.is_standard(ccd_this_atom) else ccd_chunk_size
        if to_sum_token + to_add_token > self.token_crop_size:
            break
        select_ccds_idx.add(ccd_idx)
        to_sum_atom += ccd_chunk_size
        to_sum_token += to_add_token
    # Mask every per-chain feature down to the selected conformers/atoms.
    ccd_all_id = 0
    crop_chains = []
    for chain_id in ordered_chain_ids:
        conformer_used_mask = []
        atom_used_mask = []
        ccds = []
        for ccd, chunk_size_this_ccd in zip(
                all_chain_features[chain_id]["ccds"],
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
        ):
            keep = ccd_all_id in select_ccds_idx
            if keep:
                ccds.append(ccd)
                if chain_id not in crop_chains:
                    crop_chains.append(chain_id)
            conformer_used_mask.append(keep)
            atom_used_mask += [keep] * chunk_size_this_ccd
            ccd_all_id += 1
        conf_mask = np.array(conformer_used_mask).astype(np.bool_)
        atom_mask = np.array(atom_used_mask).astype(np.bool_)
        # Update All Chain Features
        all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
        all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
            all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
        all_chain_features[chain_id]["restype"] = all_chain_features[chain_id]["restype"][conf_mask]
        all_chain_features[chain_id]["residue_index"] = all_chain_features[chain_id]["residue_index"][conf_mask]
        all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = \
            all_chain_features[chain_id]["conformer_id_to_chunk_sizes"][conf_mask]
        all_chain_features[chain_id]["ccds"] = ccds
        if "msa" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix"] = \
                all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
        if "msa_all_seq" in all_chain_features[chain_id]:
            all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
            all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
                all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
    # Remove Unused Chains
    for chain_id in list(all_chain_features.keys()):
        if chain_id not in crop_chains:
            all_chain_features.pop(chain_id, None)
    return all_chain_features
|
| 592 |
+
|
| 593 |
+
def crop_all_chain_features(self, all_chain_features, infer_meta_data):
    """Crop the assembled chain features to the configured atom/token budget.

    Delegates to the v2 spatial-cropping scheme; the meta data dict is
    passed through unchanged.
    """
    # Legacy scheme, kept for reference:
    # all_chain_features = self._spatial_crop(all_chain_features)
    cropped = self._spatial_crop_v2(all_chain_features, infer_meta_data)
    return cropped, infer_meta_data
|
| 597 |
+
|
| 598 |
+
def _make_pocket_features(self, all_chain_features):
|
| 599 |
+
# minimium distance 6-12
|
| 600 |
+
all_chain_ids = list(all_chain_features.keys())
|
| 601 |
+
|
| 602 |
+
for chain_id in all_chain_ids:
|
| 603 |
+
all_chain_features[chain_id]["pocket_res_feat"] = np.zeros(
|
| 604 |
+
[len(all_chain_features[chain_id]["ccds"])], dtype=np.bool_)
|
| 605 |
+
|
| 606 |
+
ligand_chain_ids = [i for i in all_chain_ids if i.isdigit()]
|
| 607 |
+
receptor_chain_ids = [i for i in all_chain_ids if not i.isdigit()]
|
| 608 |
+
|
| 609 |
+
use_pocket = random.random() < 0.5
|
| 610 |
+
# TODO: Inference mode assign
|
| 611 |
+
if len(ligand_chain_ids) == 0 or len(receptor_chain_ids) == 0 or not use_pocket:
|
| 612 |
+
for chain_id in all_chain_ids:
|
| 613 |
+
all_chain_features[chain_id]["pocket_res_feat"] = all_chain_features[chain_id][
|
| 614 |
+
"pocket_res_feat"].astype(np.float32)
|
| 615 |
+
return all_chain_features
|
| 616 |
+
|
| 617 |
+
# Aug Part
|
| 618 |
+
for ligand_chain_id in ligand_chain_ids:
|
| 619 |
+
x_gt_ligand = all_chain_features[ligand_chain_id]["x_gt"]
|
| 620 |
+
|
| 621 |
+
# x_gt_mean = np.mean(x_gt_ligand, axis=0) + np.random.randn(3)
|
| 622 |
+
|
| 623 |
+
for receptor_chain_id in receptor_chain_ids:
|
| 624 |
+
x_gt_receptor = all_chain_features[receptor_chain_id]["x_gt"]
|
| 625 |
+
|
| 626 |
+
# dist = np.linalg.norm(x_gt_receptor - x_gt_mean[None], axis=-1)
|
| 627 |
+
# is_pocket_atom = (dist < (random.random() * 6 + 8)).astype(np.bool_)
|
| 628 |
+
is_pocket_atom = np.any(
|
| 629 |
+
np.linalg.norm(x_gt_receptor[:, None] - x_gt_ligand[None], axis=-1) < (random.random() * 6 + 6),
|
| 630 |
+
axis=-1
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
is_pocket_ccd = []
|
| 634 |
+
offset = 0
|
| 635 |
+
for chunk_size in all_chain_features[receptor_chain_id]["conformer_id_to_chunk_sizes"]:
|
| 636 |
+
is_pocket_ccd.append(np.any(is_pocket_atom[offset:offset + chunk_size]).item())
|
| 637 |
+
offset += chunk_size
|
| 638 |
+
is_pocket_ccd = np.array(is_pocket_ccd, dtype=np.bool_)
|
| 639 |
+
|
| 640 |
+
is_pocket_ccd = np.array([np.any(i).item() for i in is_pocket_ccd], dtype=np.bool_)
|
| 641 |
+
all_chain_features[receptor_chain_id]["pocket_res_feat"] = all_chain_features[receptor_chain_id][
|
| 642 |
+
"pocket_res_feat"] | is_pocket_ccd
|
| 643 |
+
|
| 644 |
+
for chain_id in all_chain_ids:
|
| 645 |
+
all_chain_features[chain_id]["pocket_res_feat"] = all_chain_features[chain_id]["pocket_res_feat"].astype(
|
| 646 |
+
np.float32)
|
| 647 |
+
|
| 648 |
+
return all_chain_features
|
| 649 |
+
|
| 650 |
+
def _make_ccd_features(self, raw_feats, infer_meta_data):
    """Expand per-conformer features into atom- and token-level index arrays.

    Walks the cropped conformer list and, depending on CCD kind, emits:
      * UNK conformers      -> one masked token, no atoms (sentinel -1 ids);
      * standard residues   -> one token covering all its atoms, with
        centre (CA-like) and pseudo-beta atom indices;
      * ligands/non-standard -> one token per atom (atom is its own
        centre and pseudo-beta).

    Returns a dict of index arrays plus the concatenated per-atom
    reference features (`ref_feat`) and their leading-3 coordinates
    (`ref_pos`).
    """
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ccds = raw_feats["ccds"]
    atom_id_to_conformer_atom_id = raw_feats["atom_id_to_conformer_atom_id"]
    conformer_id_to_chunk_sizes = raw_feats["conformer_id_to_chunk_sizes"]

    # Atomwise
    atom_id_to_conformer_id = []
    atom_id_to_token_id = []
    ref_feat = []

    # Tokenwise
    s_mask = []
    token_id_to_conformer_id = []
    token_id_to_chunk_sizes = []
    token_id_to_centre_atom_id = []
    token_id_to_pseudo_beta_atom_id = []

    token_id = 0
    atom_id = 0  # running index into the cropped, flattened atom arrays
    for conf_id, (ccd, ccd_atoms) in enumerate(zip(ccds, conformer_id_to_chunk_sizes)):
        conf_meta_data = CONF_META_DATA[ccd]
        # UNK Conformer: masked token with no atoms; -1 marks "no atom".
        if rc.is_unk(ccd):
            s_mask.append(0)
            token_id_to_chunk_sizes.append(0)
            token_id_to_conformer_id.append(conf_id)
            token_id_to_centre_atom_id.append(-1)
            token_id_to_pseudo_beta_atom_id.append(-1)
            token_id += 1
        # Standard Residue: a single token spanning ccd_atoms atoms.
        elif rc.is_standard(ccd):
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
            atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
            ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
            token_id_to_conformer_id.append(conf_id)
            token_id_to_chunk_sizes.append(ccd_atoms.item())
            s_mask.append(1)
            for atom_id_this_ccd, atom_name in enumerate(atom_names):
                # Update Atomwise Features
                atom_id_to_conformer_id.append(conf_id)
                atom_id_to_token_id.append(token_id)
                # Update special atom ids (global atom index of CA / CB-like
                # atoms for this token).
                if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                    token_id_to_centre_atom_id.append(atom_id)
                if atom_name == rc.standard_ccd_to_token_pseudo_beta_atom_name[ccd]:
                    token_id_to_pseudo_beta_atom_id.append(atom_id)
                atom_id += 1
            token_id += 1
        # Nonestandard Residue & Ligand: one token per atom.
        else:
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
            atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
            ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
            # ref_pos_new.append(conf_meta_data["ref_pos_new"][:, inner_atom_idx])
            for atom_id_this_ccd, atom_name in enumerate(atom_names):
                # Update Atomwise Features
                atom_id_to_conformer_id.append(conf_id)
                atom_id_to_token_id.append(token_id)
                # Update Tokenwise Features: each atom is its own token and
                # its own centre / pseudo-beta atom.
                token_id_to_chunk_sizes.append(1)
                token_id_to_conformer_id.append(conf_id)
                s_mask.append(1)
                token_id_to_centre_atom_id.append(atom_id)
                token_id_to_pseudo_beta_atom_id.append(atom_id)
                atom_id += 1
                token_id += 1

    # NOTE(review): assumes at least one non-UNK conformer exists — an
    # all-UNK crop would make ref_feat empty and ref_feat[0] raise.
    if len(ref_feat) > 1:
        ref_feat = np.concatenate(ref_feat, axis=0).astype(np.float32)
    else:
        ref_feat = ref_feat[0].astype(np.float32)

    features = {
        # Atomwise
        "atom_id_to_conformer_id": np.array(atom_id_to_conformer_id, dtype=np.int64),
        "atom_id_to_token_id": np.array(atom_id_to_token_id, dtype=np.int64),
        "ref_feat": ref_feat,
        # Tokewise
        "token_id_to_conformer_id": np.array(token_id_to_conformer_id, dtype=np.int64),
        "s_mask": np.array(s_mask, dtype=np.int64),
        "token_id_to_centre_atom_id": np.array(token_id_to_centre_atom_id, dtype=np.int64),
        "token_id_to_pseudo_beta_atom_id": np.array(token_id_to_pseudo_beta_atom_id, dtype=np.int64),
        "token_id_to_chunk_sizes": np.array(token_id_to_chunk_sizes, dtype=np.int64),
    }
    # First three reference-feature channels are the reference coordinates
    # (see the ref_feat construction in _update_CONF_META_DATA).
    features["ref_pos"] = features["ref_feat"][..., :3]
    return features
|
| 737 |
+
|
| 738 |
+
def pair_and_merge(self, all_chain_features, infer_meta_data):
    """Merge per-chain features into a single assembly-level feature dict.

    Two paths:
      * single ligand chain: the chain's features are used directly; the MSA
        row (the query) is tiled to 256 rows;
      * multimer / mixed: per-chain MSA features are paired and merged via the
        module-level ``pair_and_merge`` helper, then chain-wise arrays are
        concatenated in the chain order implied by ``asym_id``.

    Afterwards conformer-wise features are expanded to token-wise ones,
    reference positions are randomly re-oriented per token, token-pair
    features (bonds, relative-token features) are filled in from
    ``CONF_META_DATA``, and bookkeeping arrays are moved into
    ``infer_meta_data``.

    Args:
        all_chain_features: dict chain_id -> per-chain feature dict
            (consumed destructively via ``pop``).
        infer_meta_data: dict with "CHAIN_CLASS", "CONF_META_DATA", "ASYM_ID";
            updated in place and returned.

    Returns:
        (raw_feats, infer_meta_data) tuple.
    """
    CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]  # dict: chain_id -> class str
    CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
    ASYM_ID = infer_meta_data["ASYM_ID"]
    homo_feats = {}

    # Create augmented pocket feature
    all_chain_features = self._make_pocket_features(all_chain_features)

    all_chain_ids = list(all_chain_features.keys())
    if len(all_chain_ids) == 1 and CHAIN_CLASS[all_chain_ids[0]] == "ligand":
        # Ligand-only system: no pairing needed, tile the single MSA row.
        ordered_chain_ids = all_chain_ids
        raw_feats = all_chain_features[all_chain_ids[0]]
        raw_feats["msa"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
        # BUG FIX: original repeated raw_feats["msa"][:1] here as well, which
        # silently replaced the deletion matrix with MSA token ids.
        raw_feats["deletion_matrix"] = np.repeat(raw_feats["deletion_matrix"][:1], 256, axis=0)
        keys = list(raw_feats.keys())

        # Keep only the features the downstream pipeline consumes.
        for feature_name in keys:
            if feature_name not in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index",
                                    "conformer_id_to_chunk_sizes", "restype", "is_protein", "is_short_poly",
                                    "is_ligand",
                                    "asym_id", "sym_id", "entity_id", "msa", "deletion_matrix", "ccds",
                                    "pocket_res_feat", "key_res_feat", "is_key_res"]:
                raw_feats.pop(feature_name)

        # Update profile and deletion mean from the (tiled) MSA.
        msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
        raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
        del msa_one_hot
        raw_feats["deletion_mean"] = (torch.atan(
            torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
        ) * (2.0 / torch.pi)).numpy()
    else:
        # Split off the MSA-related (pairable) features per chain first ...
        for chain_id in list(all_chain_features.keys()):
            homo_feats[chain_id] = {
                "asym_id": copy.deepcopy(all_chain_features[chain_id]["asym_id"]),
                "sym_id": copy.deepcopy(all_chain_features[chain_id]["sym_id"]),
                "entity_id": copy.deepcopy(all_chain_features[chain_id]["entity_id"]),
            }
        for chain_id in list(all_chain_features.keys()):
            homo_feats[chain_id]["chain_class"] = all_chain_features[chain_id].pop("chain_class")
            homo_feats[chain_id]["sequence_3"] = all_chain_features[chain_id].pop("sequence_3")
            homo_feats[chain_id]["msa"] = all_chain_features[chain_id].pop("msa")
            homo_feats[chain_id]["deletion_matrix"] = all_chain_features[chain_id].pop("deletion_matrix")
            if "msa_all_seq" in all_chain_features[chain_id]:
                homo_feats[chain_id]["msa_all_seq"] = all_chain_features[chain_id].pop("msa_all_seq")
                homo_feats[chain_id]["deletion_matrix_all_seq"] = all_chain_features[chain_id].pop(
                    "deletion_matrix_all_seq")
                homo_feats[chain_id]["msa_species_identifiers_all_seq"] = all_chain_features[chain_id].pop(
                    "msa_species_identifiers_all_seq")

        # ... then pair/merge them with the module-level helper.
        raw_feats = pair_and_merge(homo_feats, is_homomer_or_monomer=False)

        # Update profile and deletion mean from the merged MSA.
        msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
        raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
        del msa_one_hot
        raw_feats["deletion_mean"] = (torch.atan(
            torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
        ) * (2.0 / torch.pi)).numpy()

        # Merge the non-pairable features by concatenation, in asym_id order.
        ordered_asym_ids = []
        for i in raw_feats["asym_id"]:
            if i not in ordered_asym_ids:
                ordered_asym_ids.append(i)
        ordered_chain_ids = [ASYM_ID[i] for i in ordered_asym_ids]
        for feature_name in ["chain_class", "sequence_3", "assembly_num_chains", "entity_mask", "seq_length",
                             "num_alignments"]:
            raw_feats.pop(feature_name, None)
        for feature_name in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index", "conformer_id_to_chunk_sizes",
                             "restype", "is_protein", "is_short_poly", "is_ligand", "pocket_res_feat",
                             "key_res_feat", "is_key_res"]:
            raw_feats[feature_name] = np.concatenate([
                all_chain_features[chain_id].pop(feature_name) for chain_id in ordered_chain_ids
            ], axis=0)

    # Conformer-wise chain class (one entry per CCD in each chain).
    CHAIN_CLASS_NEW = []
    for chain_id in ordered_chain_ids:
        CHAIN_CLASS_NEW += [CHAIN_CLASS[chain_id]] * len(all_chain_features[chain_id]["ccds"])
    infer_meta_data["CHAIN_CLASS"] = CHAIN_CLASS_NEW

    raw_feats["ccds"] = reduce(add, [all_chain_features[chain_id].pop("ccds") for chain_id in ordered_chain_ids])

    # Create atom-wise and token-wise features from the CCD list.
    raw_feats.update(self._make_ccd_features(raw_feats, infer_meta_data))

    # Keep conformer-wise copies before token-wise expansion (needed for PDB output).
    asym_id_conformerwise = copy.deepcopy(raw_feats["asym_id"])
    residue_index_conformerwise = copy.deepcopy(raw_feats["residue_index"])

    # Expand conformer-wise features to token-wise via the gather index.
    token_id_to_conformer_id = raw_feats["token_id_to_conformer_id"]
    for key in ["is_protein", "is_short_poly", "is_ligand", "residue_index", "restype", "asym_id", "entity_id",
                "sym_id", "deletion_mean", "profile", "pocket_res_feat", "key_res_feat", "is_key_res"]:
        raw_feats[key] = raw_feats[key][token_id_to_conformer_id]
    for key in ["msa", "deletion_matrix"]:
        if key in raw_feats:
            raw_feats[key] = raw_feats[key][:, token_id_to_conformer_id]

    ###################################################
    #     Centre random augmentation of ref pos       #
    ###################################################
    raw_feats["ref_pos"] = centre_random_augmentation_np_apply(
        raw_feats["ref_pos"], raw_feats["atom_id_to_token_id"]).astype(np.float32)
    raw_feats["ref_feat"][:, :3] = raw_feats["ref_pos"]

    ###################################################
    #           Create token pair features            #
    ###################################################
    no_token = len(raw_feats["token_id_to_conformer_id"])
    token_bonds = np.zeros([no_token, no_token], dtype=np.float32)
    rel_tok_feat = np.zeros([no_token, no_token, 42], dtype=np.float32)
    batch_ref_pos = np.zeros([32, no_token, 3], dtype=np.float32)
    offset = 0  # token offset
    atom_offset = 0  # atom offset
    for ccd, len_atoms in zip(
            raw_feats["ccds"],
            raw_feats["conformer_id_to_chunk_sizes"]
    ):
        if rc.is_standard(ccd) or rc.is_unk(ccd):
            # Standard residues are a single token; pair features stay zero.
            offset += 1
        else:
            # Non-standard residue / ligand: one token per atom; copy the
            # precomputed conformer pair features for the retained atoms.
            len_atoms = len_atoms.item()
            inner_atom_idx = raw_feats["atom_id_to_conformer_atom_id"][atom_offset:atom_offset + len_atoms]
            batch_ref_pos[:, offset:offset + len_atoms] = CONF_META_DATA[ccd]["batch_ref_pos"][:, inner_atom_idx]
            token_bonds[offset:offset + len_atoms, offset:offset + len_atoms] = \
                CONF_META_DATA[ccd]["token_bonds"][inner_atom_idx][:, inner_atom_idx]
            rel_tok_feat[offset:offset + len_atoms, offset:offset + len_atoms] = \
                CONF_META_DATA[ccd]["rel_tok_feat"][inner_atom_idx][:, inner_atom_idx]
            offset += len_atoms
        # NOTE(review): atoms are consumed by every conformer, standard or not.
        atom_offset += len_atoms
    raw_feats["token_bonds"] = token_bonds.astype(np.float32)
    raw_feats["token_bonds_feature"] = token_bonds.astype(np.float32)
    raw_feats["rel_tok_feat"] = rel_tok_feat.astype(np.float32)
    raw_feats["batch_ref_pos"] = batch_ref_pos.astype(np.float32)

    ###################################################
    #            Chirality augmentation               #
    ###################################################
    if not self.inference_mode:
        # TODO: make the chirality-drop probability configurable.
        chirality_seed = random.random()
        if chirality_seed < 0.1:
            # With p=0.1, randomly blank the chirality channels of ~half of
            # the ligand atoms (channel 2 acts as the "no chirality" flag).
            ref_chirality = raw_feats["ref_feat"][:, 158:161]
            ref_chirality_replace = np.zeros_like(ref_chirality)
            ref_chirality_replace[:, 2] = 1

            is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_token_id"]]
            remove_chirality = (np.random.randint(0, 2, [len(is_ligand_atom)]) * is_ligand_atom).astype(
                np.bool_)
            ref_chirality = np.where(remove_chirality[:, None], ref_chirality_replace, ref_chirality)
            raw_feats["ref_feat"][:, 158:161] = ref_chirality

    # Masks
    raw_feats["x_exists"] = np.ones_like(raw_feats["x_gt"][..., 0]).astype(np.float32)
    raw_feats["a_mask"] = raw_feats["x_exists"]
    raw_feats["s_mask"] = np.ones_like(raw_feats["asym_id"]).astype(np.float32)
    raw_feats["ref_space_uid"] = raw_feats["atom_id_to_conformer_id"]

    # Move bookkeeping arrays (conformer-wise) into the inference metadata.
    infer_meta_data["ccds"] = raw_feats.pop("ccds")
    infer_meta_data["atom_id_to_conformer_atom_id"] = raw_feats.pop("atom_id_to_conformer_atom_id")
    infer_meta_data["residue_index"] = residue_index_conformerwise
    infer_meta_data["asym_id"] = asym_id_conformerwise
    infer_meta_data["conformer_id_to_chunk_sizes"] = raw_feats.pop("conformer_id_to_chunk_sizes")

    return raw_feats, infer_meta_data
|
| 906 |
+
|
| 907 |
+
def make_feats(self, tensors):
    """Assemble ``target_feat`` and ``msa_feat`` and drop their raw inputs.

    ``target_feat`` = one-hot restype (32) + profile (32) + deletion_mean (1).
    ``msa_feat``    = one-hot MSA (32) + has_deletion (1) + deletion_value (1),
    computed on the query row plus up to 127 randomly chosen MSA rows.

    Consumes (pops) "msa", "deletion_mean", "profile", "deletion_matrix".
    """
    restype_oh = F.one_hot(tensors["restype"].long(), 32).float()
    profile = tensors["profile"].float()
    deletion_mean = tensors["deletion_mean"][..., None].float()
    tensors["target_feat"] = torch.cat([restype_oh, profile, deletion_mean], dim=-1)

    # Always keep the query (row 0), then sample up to 127 additional rows.
    row_ids = [0] + torch.randperm(len(tensors["msa"]))[:127].tolist()
    msa = tensors["msa"][row_ids]
    deletion = tensors["deletion_matrix"][row_ids]
    tensors["msa"] = msa
    tensors["deletion_matrix"] = deletion

    has_deletion = torch.clamp(deletion.float(), min=0., max=1.)
    # pi derived on the tensor's device so the division below stays on-device.
    pi = torch.acos(torch.zeros(1, device=deletion.device)) * 2
    deletion_value = torch.atan(deletion / 3.) * (2. / pi)
    tensors["msa_feat"] = torch.cat([
        F.one_hot(msa.long(), 32).float(),
        has_deletion[..., None].float(),
        deletion_value[..., None].float(),
    ], dim=-1)

    # Raw inputs are fully folded into the feature tensors above.
    for consumed in ("msa", "deletion_mean", "profile", "deletion_matrix"):
        tensors.pop(consumed, None)

    return tensors
|
| 935 |
+
|
| 936 |
+
def _make_token_bonds(self, tensors):
    """Add inter-chain token bonds for the closest polymer-ligand or
    ligand-ligand atom pair of each chain pair.

    For every ordered pair of chains where at least one side is a ligand,
    the single closest (masked) atom pair is found; if it is closer than
    ``self.token_bond_threshold`` a symmetric bond is recorded between the
    two tokens owning those atoms. The result is added onto the
    intra-conformer ``token_bonds`` already present in ``tensors``.
    """
    # Get Polymer-Ligand & Ligand-Ligand Within Conformer Token Bond

    # Broadcast token-wise labels down to atoms.
    asym_id = tensors["asym_id"][tensors["atom_id_to_token_id"]]
    is_ligand = tensors["is_ligand"][tensors["atom_id_to_token_id"]]

    x_gt = tensors["x_gt"]
    a_mask = tensors["a_mask"]

    atom_id_to_token_id = tensors["atom_id_to_token_id"]

    num_token = len(tensors["asym_id"])
    between_conformer_token_bonds = torch.zeros([num_token, num_token])

    # Chain-wise bookkeeping: one entry per contiguous asym_id run.
    # NOTE(review): assumes atoms of a chain are contiguous — confirm upstream.
    asym_id_chain = []        # asym_id of each chain
    asym_id_atom_offset = []  # first atom index of each chain
    asym_id_is_ligand = []    # ligand flag of each chain's first atom
    for atom_offset, (a_id, i_id) in enumerate(zip(asym_id.tolist(), is_ligand.tolist())):
        if len(asym_id_chain) == 0 or asym_id_chain[-1] != a_id:
            asym_id_chain.append(a_id)
            asym_id_atom_offset.append(atom_offset)
            asym_id_is_ligand.append(i_id)

    len_asym_id_chain = len(asym_id_chain)
    if len_asym_id_chain >= 2:
        for i in range(0, len_asym_id_chain - 1):
            asym_id_i = asym_id_chain[i]
            mask_i = asym_id == asym_id_i
            x_gt_i = x_gt[mask_i]
            a_mask_i = a_mask[mask_i]
            for j in range(i + 1, len_asym_id_chain):
                # Polymer-polymer pairs never get a distance-derived bond.
                if not bool(asym_id_is_ligand[i]) and not bool(asym_id_is_ligand[j]):
                    continue
                asym_id_j = asym_id_chain[j]
                mask_j = asym_id == asym_id_j
                x_gt_j = x_gt[mask_j]
                a_mask_j = a_mask[mask_j]
                # Pairwise distances; masked-out atoms pushed far away (+1000).
                dis_ij = torch.norm(x_gt_i[:, None, :] - x_gt_j[None, :, :], dim=-1)
                dis_ij = dis_ij + (1 - a_mask_i[:, None] * a_mask_j[None]) * 1000
                if torch.min(dis_ij) < self.token_bond_threshold:
                    # Flat argmin -> (row, col) of the closest atom pair.
                    ij = torch.argmin(dis_ij).item()
                    l_j = len(x_gt_j)
                    atom_i = int(ij // l_j)  # row index (chain i atom)
                    atom_j = int(ij % l_j)  # col index (chain j atom)
                    global_atom_i = atom_i + asym_id_atom_offset[i]
                    global_atom_j = atom_j + asym_id_atom_offset[j]
                    token_i = atom_id_to_token_id[global_atom_i]
                    token_j = atom_id_to_token_id[global_atom_j]

                    between_conformer_token_bonds[token_i, token_j] = 1
                    between_conformer_token_bonds[token_j, token_i] = 1
    token_bond_seed = random.random()
    tensors["token_bonds"] = tensors["token_bonds"] + between_conformer_token_bonds
    # Docking Indicate Token Bond
    # NOTE(review): seed >= 0 is always true, so the feature copy always
    # happens; the seed looks like a disabled augmentation switch.
    if token_bond_seed >= 0:
        tensors["token_bonds_feature"] = tensors["token_bonds"]
    return tensors
|
| 996 |
+
|
| 997 |
+
def _pad_to_size(self, tensors):
    """Right-pad token-wise and atom-wise tensors with zeros up to the fixed
    crop sizes ``self.token_crop_size`` / ``self.atom_crop_size``.

    No-op along an axis that is already at (or beyond) its crop size.
    """
    pad = torch.nn.functional.pad
    missing_atoms = self.atom_crop_size - len(tensors["x_gt"])
    missing_tokens = self.token_crop_size - len(tensors["residue_index"])

    if missing_tokens > 0:
        token_vectors = (
            "restype", "residue_index", "is_protein", "is_short_poly", "is_ligand", "is_key_res",
            "asym_id", "entity_id", "sym_id", "token_id_to_conformer_id", "s_mask",
            "token_id_to_centre_atom_id", "token_id_to_pseudo_beta_atom_id", "token_id_to_chunk_sizes",
            "pocket_res_feat",
        )
        for key in token_vectors:
            tensors[key] = pad(tensors[key], [0, missing_tokens])
        # [..., token, channel] tensors; some are optional.
        for key in ("target_feat", "msa_feat", "batch_ref_pos", "key_res_feat"):
            if key in tensors:
                tensors[key] = pad(tensors[key], [0, 0, 0, missing_tokens])
        # [token, token] pair tensors.
        for key in ("token_bonds", "token_bonds_feature"):
            tensors[key] = pad(tensors[key], [0, missing_tokens, 0, missing_tokens])
        # [token, token, channel] pair tensor.
        for key in ("rel_tok_feat",):
            tensors[key] = pad(tensors[key], [0, 0, 0, missing_tokens, 0, missing_tokens])

    if missing_atoms > 0:
        for key in ("a_mask", "x_exists", "atom_id_to_conformer_id", "atom_id_to_token_id", "ref_space_uid"):
            tensors[key] = pad(tensors[key], [0, missing_atoms])
        # [atom, channel] tensors.
        for key in ("x_gt", "ref_feat", "ref_pos"):
            tensors[key] = pad(tensors[key], [0, 0, 0, missing_atoms])

    return tensors
|
| 1027 |
+
|
| 1028 |
+
def get_template_feat(self, tensors):
    """Build a single pseudo-template feature from the ground-truth structure.

    The template is a distogram over pseudo-beta positions, restricted to
    protein-protein token pairs, concatenated with its own mask. At
    inference the template is always enabled; during training it is kept
    with probability 0.1 (``t_mask`` gates its use downstream).
    """
    x_gt = tensors["x_gt"][tensors["token_id_to_pseudo_beta_atom_id"]]
    z_mask = tensors["z_mask"]
    asym_id = tensors["asym_id"]
    is_protein = tensors["is_protein"]
    # NOTE(review): chain_same is computed but never used below.
    chain_same = (asym_id[None] == asym_id[:, None]).float()
    protein2d = is_protein[None] * is_protein[:, None]
    dgram = dgram_from_positions(x_gt)
    # Zero out non-protein pairs and padded positions.
    dgram = dgram * protein2d[..., None] * z_mask[..., None]

    # if not self.inference_mode:
    #     bert_mask = torch.rand([len(x_gt)]) > random.random() * 0.4
    #     asym_ids = list(set(asym_id.tolist()))
    #     used_asym_ids = []
    #     for a in asym_ids:
    #         if random.random() > 0.6:
    #             used_asym_ids.append(a)
    #     if len(used_asym_ids) > 0:
    #         used_asym_ids = torch.tensor(used_asym_ids)
    #         chain_bert_mask = torch.any(asym_id[:, None] == used_asym_ids[None], dim=-1)
    #         bert_mask = chain_bert_mask * bert_mask
    #     else:
    #         bert_mask = bert_mask * 0
    #     template_pseudo_beta_mask = (bert_mask[None] * bert_mask[:, None]) * z_mask * protein2d
    # else:
    #     template_pseudo_beta_mask = z_mask * protein2d
    template_pseudo_beta_mask = z_mask * protein2d
    # template_pseudo_beta_mask = protein2d * z_mask
    dgram = dgram * template_pseudo_beta_mask[..., None]
    templ_feat = torch.cat([dgram, template_pseudo_beta_mask[..., None]], dim=-1)
    # Leading [None] adds the template dimension (one template).
    tensors["templ_feat"] = templ_feat.float()[None]
    t_mask_seed = random.random()
    # Template augmentation: randomly drop the template during training.
    if self.inference_mode or t_mask_seed < 0.1:
        tensors["t_mask"] = torch.ones([len(tensors["templ_feat"])], dtype=torch.float32)
    else:
        tensors["t_mask"] = torch.zeros([len(tensors["templ_feat"])], dtype=torch.float32)

    # TODO: No Template
    # if not self.inference_mode:
    #     if random.random() < 0.5:
    #         tensors["templ_feat"] *= 0
    return tensors
|
| 1071 |
+
|
| 1072 |
+
def transform(self, raw_feats):
    """Convert merged numpy features to torch tensors and run the feature
    pipeline: target/MSA features, token bonds, (training-only) padding,
    pair mask, template features, and molecule-type flag fix-up.
    """
    # numpy -> torch
    tensors = {name: torch.from_numpy(array) for name, array in raw_feats.items()}

    # Target & MSA features
    tensors = self.make_feats(tensors)

    # Token bond features
    tensors = self._make_token_bonds(tensors)

    # Pad to fixed crop sizes (training only)
    if not self.inference_mode:
        tensors = self._pad_to_size(tensors)

    # Pair mask from the token-wise mask
    tensors["z_mask"] = tensors["s_mask"][None] * tensors["s_mask"][:, None]

    # Template features
    tensors = self.get_template_feat(tensors)

    # Type correction: short polymers count as protein, not ligand.
    short_poly = tensors.pop("is_short_poly")
    tensors["is_protein"] = tensors["is_protein"] + short_poly
    tensors["is_ligand"] = tensors["is_ligand"] - short_poly
    # Nucleic-acid flags are always zero in this loader.
    tensors["is_dna"] = torch.zeros_like(tensors["is_protein"])
    tensors["is_rna"] = torch.zeros_like(tensors["is_protein"])
    return tensors
|
| 1109 |
+
|
| 1110 |
+
def load(self, sample_id):
    """Load one system by id: read labels, build per-chain features, crop,
    merge, and transform to tensors.

    Returns:
        (tensors, infer_meta_data) tuple.
    """
    labels_path = os.path.join(self.samples_path, f"{sample_id}.pkl.gz")
    all_chain_labels = load_pkl(labels_path)

    all_chain_features, infer_meta_data = self.load_all_chain_features(all_chain_labels)
    infer_meta_data["system_id"] = sample_id
    all_chain_features, infer_meta_data = self.crop_all_chain_features(all_chain_features, infer_meta_data)
    raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)

    return self.transform(raw_feats), infer_meta_data
|
| 1121 |
+
|
| 1122 |
+
def random_load(self):
    """Sample a training system id according to ``self.probabilities`` and load it."""
    picked = torch.multinomial(self.probabilities, 1).item()
    sample_id = self.used_sample_ids[picked]
    print(sample_id)
    return self.load(sample_id)
|
| 1127 |
+
|
| 1128 |
+
def random_load_test(self):
    """Sample a test-set system id according to ``self.test_probabilities`` and load it."""
    picked = torch.multinomial(self.test_probabilities, 1).item()
    sample_id = self.used_test_sample_ids[picked]
    print(sample_id)
    return self.load(sample_id)
|
| 1133 |
+
|
| 1134 |
+
def weighted_random_load(self):
    """Load a sample: 95% a weighted system draw, 5% a ligand-only mol chunk."""
    weight_seed = random.random()
    if weight_seed >= 0.95:
        return self.random_load_mol_chunks()
    picked = torch.multinomial(self.probabilities, 1).item()
    return self.load(self.used_sample_ids[picked])
|
| 1141 |
+
|
| 1142 |
+
def load_ligand(self, sample_id, chain_features):
    """Build a full (ligand-only) system from a single ligand feature dict.

    Wraps ``chain_features`` as a one-chain assembly under a synthetic CCD
    code, registers its conformer metadata, and runs the normal
    pair/merge/transform pipeline.

    Args:
        sample_id: identifier used for the chain id and ``system_id``.
        chain_features: per-ligand feature dict (mutated in place); expected
            to contain "restype", "all_atom_positions", "all_atom_mask",
            "ref_atom_name_chars" — TODO confirm full schema against caller.

    Returns:
        (tensors, infer_meta_data) tuple.
    """
    all_chain_features = {}
    CHAIN_META_DATA = {
        "spatial_crop_chain_ids": None,
        "chain_class": {},
        "chain_sequence_3s": {},
        "fake_ccds": [],
    }
    CONF_META_DATA = {}
    # Synthetic 3-char CCD code, left-padded with '#' (e.g. "##0") so it
    # cannot collide with any real CCD id.
    num_prev_fake_ccds = len(CHAIN_META_DATA["fake_ccds"])
    fake_ccd = f"{num_prev_fake_ccds:#>3}"

    # Single-row "MSA" = the ligand's restype sequence itself.
    chain_features["msa"] = chain_features["restype"][None]
    chain_features["deletion_matrix"] = np.ones_like(chain_features["msa"])
    chain_features["ccds"] = [fake_ccd]
    chain_features["chain_class"] = "ligand"
    # Add a leading residue axis: [num_res(=1), num_atoms, ...].
    chain_features["all_atom_positions"] = chain_features["all_atom_positions"][None]
    chain_features["all_atom_mask"] = chain_features["all_atom_mask"][None]
    # Update chain- and conformer-level metadata.
    CHAIN_META_DATA["fake_ccds"].append(fake_ccd)
    sequence_3 = fake_ccd
    chain_id = f"SDFM_{sample_id}"
    CHAIN_META_DATA["chain_sequence_3s"][chain_id] = sequence_3
    CHAIN_META_DATA["chain_class"][chain_id] = "ligand"
    CONF_META_DATA = self._update_CONF_META_DATA_ligand(
        CONF_META_DATA, sequence_3, chain_features)
    all_chain_features[chain_id] = chain_features
    # NOTE(review): immediately overwritten by the updated copy below.
    all_chain_features[chain_id] = self._update_chain_feature(
        chain_features,
        CONF_META_DATA
    )
    SEQ3 = {}
    CHAIN_CLASS = {}
    SEQ3[chain_id] = "-".join([fake_ccd])
    all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3)
    # One conformer containing all ligand atoms.
    all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = np.array(
        [len(chain_features["ref_atom_name_chars"])], dtype=np.int64)

    all_chain_features[chain_id]["x_gt"] = chain_features["all_atom_positions"][0]
    all_chain_features[chain_id]["x_exists"] = chain_features["all_atom_mask"][0]
    CHAIN_CLASS[chain_id] = "ligand"
    infer_meta_data = {
        "CONF_META_DATA": CONF_META_DATA,
        "SEQ3": SEQ3,
        "ASYM_ID": ASYM_ID,
        "CHAIN_CLASS": CHAIN_CLASS
    }

    # if not self.inference_mode:
    #     all_chain_features, infer_meta_data = self.crop_all_chain_features(all_chain_features, infer_meta_data)

    raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)

    tensors = self.transform(raw_feats)
    infer_meta_data["system_id"] = sample_id
    return tensors, infer_meta_data
|
| 1198 |
+
|
| 1199 |
+
def random_load_mol_chunks(self, ligand_db_name="1"):
    """Load a random ligand-only sample from one of the pre-built mol DBs.

    Args:
        ligand_db_name: which database to draw from ("0", "1" or "2";
            "1" additionally picks a random shard).

    Raises:
        ValueError: if ``ligand_db_name`` is not one of the known DBs.
    """
    if ligand_db_name == "0":
        db_path = "/2022133002/projects/stdock/stdock_v9.5/scripts/try_new.pkl.gz"
    elif ligand_db_name == "1":
        shard = random.randint(1, 374)
        db_path = f"/2022133002/data/ligand_samples/samples_{shard}.pkl.gz"
    elif ligand_db_name == "2":
        db_path = "/2022133002/projects/stdock/stdock_v9.5/scripts/try_400k_2_new.pkl.gz"
    else:
        raise ValueError("MOL DB Name is Wrong!")
    ligand_db = load_pkl(db_path)
    sample_id = random.choice(list(ligand_db.keys()))
    return self.load_ligand(sample_id, ligand_db[sample_id])
|
| 1213 |
+
|
| 1214 |
+
def write_pdb(self, x_pred, fname, infer_meta_data):
    """Write predicted coordinates to a single-model PDB file.

    Walks conformers in order, emits one ATOM/HETATM record per predicted
    atom (ligand conformers get HETATM), and stops once every atom in
    ``atom_id_to_conformer_atom_id`` has been written.

    Args:
        x_pred: per-atom predicted coordinates, indexable as [atom_id] -> xyz.
        fname: output path.
        infer_meta_data: bookkeeping dict produced by ``pair_and_merge``.

    NOTE(review): the fixed-width column spacing in the f-strings below was
    reconstructed from a rendering that may have collapsed runs of spaces —
    verify the output against the PDB column specification.
    """
    ccds = infer_meta_data["ccds"]
    atom_id_to_conformer_atom_id = infer_meta_data["atom_id_to_conformer_atom_id"]
    ccd_chunk_sizes = infer_meta_data["conformer_id_to_chunk_sizes"].tolist()
    CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]
    conf_meta_data = infer_meta_data["CONF_META_DATA"]
    residue_index = infer_meta_data["residue_index"].tolist()
    asym_id = infer_meta_data["asym_id"].tolist()

    atom_lines = []
    atom_offset = 0
    for ccd_id, (ccd, chunk_size, res_id) in enumerate(zip(ccds, ccd_chunk_sizes, residue_index)):
        # Atom names/elements of the retained atoms of this conformer.
        inner_atom_idx = atom_id_to_conformer_atom_id[atom_offset:atom_offset + chunk_size]
        atom_names = [conf_meta_data[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
        atom_elements = [PeriodicTable[conf_meta_data[ccd]["ref_element"][i]] for i in inner_atom_idx]
        chain_tag = PDB_CHAIN_IDS[int(asym_id[ccd_id])]
        record_type = "HETATM" if CHAIN_CLASS[ccd_id] == "ligand" else "ATOM"

        for ccd_atom_idx, atom_name in enumerate(atom_names):
            x = x_pred[atom_offset]
            # 4-char atom names start in column 13; shorter ones in column 14.
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            res_name_3 = ccd
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_elements[ccd_atom_idx]
            # b_factor = torch.argmax(plddt[atom_offset],dim=-1).item()*2 +1
            # Constant placeholder B-factor (pLDDT writing is disabled above).
            b_factor = 70.
            charge = 0
            pos = x.tolist()
            atom_line = (
                f"{record_type:<6}{atom_offset + 1:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3.split()[0]:>3} {chain_tag:>1}"
                f"{res_id + 1:>4}{insertion_code:>1} "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f} "
                f"{element:>2}{charge:>2}"
            )
            atom_lines.append(atom_line)
            atom_offset += 1
        # Stop early once all atoms are written (remaining entries are padding).
        if atom_offset == len(atom_id_to_conformer_atom_id):
            break
    out = "\n".join(atom_lines)
    out = f"MODEL 1\n{out}\nTER\nENDMDL\nEND"
    dump_txt(out, fname)
|
PhysDock/data/generate_system.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from Bio.PDB import PDBParser
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
|
| 10 |
+
from PhysDock.utils.io_utils import dump_pkl, load_pkl, convert_md5_string
|
| 11 |
+
from PhysDock.data.constants.PDBData import protein_letters_3to1_extended
|
| 12 |
+
|
| 13 |
+
warnings.filterwarnings("ignore")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def generate_system(
        receptor_pdb_path,
        ligand_sdf_path,
        ligand_ccd_id,
        systems_dir,
        ccd_id_meta_data=None,
):
    """Parse a receptor PDB and a ligand SDF into a system feature pickle.

    Writes ``<systems_dir>/<name>_<chain ids>.pkl.gz`` containing per-chain
    atom positions/masks/CCD codes, plus one FASTA per protein chain under
    ``<systems_dir>/fastas`` (named by an md5 of the sequence) for homology
    search.

    Args:
        receptor_pdb_path (str): Path to the receptor PDB file.
        ligand_sdf_path (str): Path to the ligand SDF file.
        ligand_ccd_id (str): CCD ID of the ligand (upper-cased on save).
        systems_dir (str): Directory to save system pickles and FASTAs.
        ccd_id_meta_data (dict | None): CCD metadata; loaded from the bundled
            params file when None.
    """
    if ccd_id_meta_data is None:
        print("Loading CCD meta data ...")
        ccd_id_meta_data = load_pkl(os.path.join(os.path.split(__file__)[0], "../../params/ccd_id_meta_data.pkl.gz"))
    os.makedirs(systems_dir, exist_ok=True)

    # Interaction channels (to be filled by PLIP; zeros here). Defined once up
    # front — the original defined this inside the per-chain loop, so it was
    # redundantly rebuilt per chain and raised NameError for the ligand
    # section when the receptor model contained no chains.
    interaction_keys = ['salt bridges', 'pi-cation interactions', 'hydrophobic interactions',
                        'pi-stacking', 'hydrogen bonds', 'metal complexes']

    # Initialize parser and data containers
    pdb_parser = PDBParser()
    structure = pdb_parser.get_structure("", receptor_pdb_path)
    model = structure[0]

    all_chain_features = {}
    used_chain_ids = []

    # Extract protein chains from the PDB.
    for chain in model:
        chain_id = chain.id
        used_chain_ids.append(chain_id)
        all_chain_features[chain_id] = {
            "all_atom_positions": [],
            "all_atom_mask": [],
            "ccds": []
        }

        # Residue numbering is re-based to the chain's first residue.
        offset = None
        for residue in chain:
            if offset is None:
                offset = int(residue.id[1])

            resname = residue.get_resname().strip().ljust(3)
            res_idx = int(residue.id[1]) - offset
            num_atoms = len(ccd_id_meta_data[resname]["ref_atom_name_chars"])

            # Fill gaps in the numbering with masked UNK placeholder residues.
            while len(all_chain_features[chain_id]["ccds"]) < res_idx:
                all_chain_features[chain_id]["ccds"].append("UNK")
                all_chain_features[chain_id]["all_atom_positions"].append(np.zeros([1, 3], dtype=np.float32))
                all_chain_features[chain_id]["all_atom_mask"].append(np.zeros([1], dtype=np.int8))

            # Initialize this residue's atom arrays.
            all_chain_features[chain_id]["ccds"].append(resname)
            all_chain_features[chain_id]["all_atom_positions"].append(np.zeros([num_atoms, 3], dtype=np.float32))
            all_chain_features[chain_id]["all_atom_mask"].append(np.zeros([num_atoms], dtype=np.int8))

            # Copy coordinates for atoms known to the CCD reference conformer.
            # NOTE(review): assumes residue ids are non-decreasing so res_idx
            # addresses the entry appended above — confirm for insertion codes.
            ref_atom_names = ccd_id_meta_data[resname]["ref_atom_name_chars"]
            for atom in residue:
                if atom.name in ref_atom_names:
                    atom_idx = ref_atom_names.index(atom.name)
                    all_chain_features[chain_id]["all_atom_positions"][res_idx][atom_idx] = atom.coord
                    all_chain_features[chain_id]["all_atom_mask"][res_idx][atom_idx] = 1

        # Placeholder interaction features.  # TODO PLIP
        for key in interaction_keys:
            all_chain_features[chain_id][key] = np.zeros(len(all_chain_features[chain_id]["ccds"]), dtype=np.int8)

    # Extract the ligand from the SDF (hydrogens removed).
    supplier = Chem.SDMolSupplier(ligand_sdf_path, removeHs=True, sanitize=False)
    mol = supplier[0]
    mol = Chem.RemoveAllHs(mol)
    conf = mol.GetConformer()
    ligand_chain_id = "1"
    used_chain_ids.append(ligand_chain_id)

    ligand_atom_count = mol.GetNumAtoms()
    ligand_positions = np.zeros([ligand_atom_count, 3], dtype=np.float32)
    ligand_masks = np.ones([ligand_atom_count], dtype=np.int8)

    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        pos = conf.GetAtomPosition(idx)
        ligand_positions[idx] = [pos.x, pos.y, pos.z]

    all_chain_features[ligand_chain_id] = {
        "all_atom_positions": [ligand_positions],
        "all_atom_mask": [ligand_masks],
        "ccds": [ligand_ccd_id.upper()]
    }

    for key in interaction_keys:
        all_chain_features[ligand_chain_id][key] = np.zeros(1, dtype=np.int8)

    # Generate the system pickle file.
    save_name = os.path.basename(receptor_pdb_path).replace('.pdb', '')
    for cid in used_chain_ids:
        save_name += f"_{cid}"

    dump_pkl(all_chain_features, os.path.join(systems_dir, f"{save_name}.pkl.gz"))

    # Generate FASTA files to run homology search (protein chains only).
    for cid, features in all_chain_features.items():
        if cid == ligand_chain_id:
            continue
        sequence = ''.join(protein_letters_3to1_extended.get(ccd, "X") for ccd in features["ccds"])
        md5_hash = convert_md5_string(f"protein:{sequence}")
        os.makedirs(os.path.join(systems_dir, "fastas"), exist_ok=True)
        with open(os.path.join(systems_dir, "fastas", f"{md5_hash}.fasta"), "w") as f:
            f.write(f">{md5_hash}\n{sequence}\n")
    print("Make system successfully!")
|
PhysDock/data/relaxation.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
import tqdm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
sys.path.append("../")
|
| 10 |
+
from PhysDock.utils.io_utils import run_pool_tasks, load_txt, dump_txt
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import openmm.app as mm_app
|
| 13 |
+
import openmm.unit as mm_unit
|
| 14 |
+
import openmm as mm
|
| 15 |
+
import os.path
|
| 16 |
+
import sys
|
| 17 |
+
import mdtraj
|
| 18 |
+
from openmm.app import PDBFile, Modeller
|
| 19 |
+
import pdbfixer
|
| 20 |
+
from openmmforcefields.generators import SystemGenerator
|
| 21 |
+
from openff.toolkit import Molecule
|
| 22 |
+
from openff.toolkit.utils.exceptions import UndefinedStereochemistryError, RadicalsNotSupportedError
|
| 23 |
+
from openmm import CustomExternalForce
|
| 24 |
+
from posebusters import PoseBusters
|
| 25 |
+
from posebusters.posebusters import _dataframe_from_output
|
| 26 |
+
from posebusters.cli import _select_mode, _format_results
|
| 27 |
+
|
| 28 |
+
|
def get_bust_results( # noqa: PLR0913
        mol_pred,
        mol_true,
        mol_cond,
        top_n: int | None = None,
):
    """Run PoseBusters on a single (predicted, true, conditioning) triple.

    Wraps PoseBusters' private pipeline (``_select_mode``, ``_run``,
    ``_dataframe_from_output``) to score one predicted ligand pose against
    the reference ligand and its receptor.

    Args:
        mol_pred: Path to the predicted ligand file.
        mol_true: Path to the reference (true) ligand file.
        mol_cond: Path to the conditioning receptor file (each bust run
            uses a different receptor).
        top_n: Forwarded to PoseBusters; limits the number of poses checked.

    Returns:
        A pandas DataFrame with the full PoseBusters report for the first
        yielded result, or ``None`` when PoseBusters yields nothing.
    """
    mol_pred = [Path(mol_pred)]
    mol_true = Path(mol_true)
    mol_cond = Path(mol_cond)  # Each bust running has different receptor

    # run on single input
    # NOTE(review): this is a *set* comprehension — it collects the names of
    # the truthy arguments; _select_mode is assumed to accept a collection
    # of column names. Verify against the installed posebusters version.
    d = {k for k, v in dict(mol_pred=mol_pred, mol_true=mol_true, mol_cond=mol_cond).items() if v}
    mode = _select_mode(None, d)
    posebusters = PoseBusters(mode, top_n=top_n)
    cols = ["mol_pred", "mol_true", "mol_cond"]
    # The comprehension rebinds mol_pred as it iterates the one-element
    # list above, producing one row per predicted pose (exactly one here).
    posebusters.file_paths = pd.DataFrame([[mol_pred, mol_true, mol_cond] for mol_pred in mol_pred], columns=cols)
    posebusters_results = posebusters._run()
    results = None
    # _run() yields lazily; only the first result is converted to a frame.
    for i, results_dict in enumerate(posebusters_results):
        results = _dataframe_from_output(results_dict, posebusters.config, full_report=True)
        break
    return results
| 51 |
+
|
| 52 |
+
|
def fix_pdb(pdbname, outdir, file_name):
    """Repair a PDB file with PDBFixer and protonate it at pH 7.0.

    The repair steps run in the order PDBFixer requires: detect missing
    residues, detect and replace non-standard residues, detect and add
    missing atoms, then add hydrogens.

    Args:
        pdbname: Path of the PDB file to repair.
        outdir: Unused; retained for interface compatibility.
        file_name: Unused; retained for interface compatibility.

    Returns:
        Tuple of (OpenMM topology, positions) for the repaired structure.
    """
    pdb_fixer = pdbfixer.PDBFixer(pdbname)
    repair_steps = (
        pdb_fixer.findMissingResidues,
        pdb_fixer.findNonstandardResidues,
        pdb_fixer.replaceNonstandardResidues,
        pdb_fixer.findMissingAtoms,
        pdb_fixer.addMissingAtoms,
    )
    for step in repair_steps:
        step()
    pdb_fixer.addMissingHydrogens(7.0)
    return pdb_fixer.topology, pdb_fixer.positions
| 69 |
+
|
| 70 |
+
|
def set_system(topology):
    """Create an OpenMM System for *topology* using the Amber14 force field.

    The system uses no nonbonded cutoff, rigid water, and keeps
    center-of-mass motion (``removeCMMotion=False``).
    """
    ff = mm_app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
    return ff.createSystem(
        topology,
        removeCMMotion=False,
        nonbondedMethod=mm_app.NoCutoff,
        rigidWater=True,
    )
| 84 |
+
|
| 85 |
+
|
def minimize_energy(
        topology,
        system,
        positions,
        outdir,
        out_title
):
    """Minimize the energy of *system* and write the minimized structure.

    A Brownian integrator is constructed because Simulation requires one,
    but only the minimizer is used. The minimized structure is written to
    ``<outdir>/<out_title>.pdb`` when *out_title* contains
    "relaxed_complex", otherwise to ``<out_title>.pdb`` in the current
    working directory. A DCD trajectory is recorded to ``positions.dcd``
    in the working directory.

    Args:
        topology: OpenMM topology of the structure.
        system: OpenMM System with forces already configured.
        positions: Starting coordinates.
        outdir: Output directory used for "relaxed_complex" outputs.
        out_title: Base name (without extension) of the written PDB.

    Returns:
        Tuple of (topology, minimized positions, minimized potential energy).
    """
    # Integrator parameters are irrelevant here; only minimizeEnergy runs.
    integrator = mm.BrownianIntegrator(
        100 * mm.unit.kelvin,
        100. / mm.unit.picoseconds,
        2.0 * mm.unit.femtoseconds
    )
    simulation = mm.app.Simulation(topology, system, integrator)

    # Record frames (if any) to positions.dcd in the working directory.
    reportInterval = 100  # Adjust this value as needed
    reporter = mdtraj.reporters.DCDReporter('positions.dcd', reportInterval)
    simulation.reporters.append(reporter)

    try:
        simulation.context.setPositions(positions)

        # tolerance=1 (kJ/mol/nm), at most 100 iterations
        simulation.minimizeEnergy(1, 100)
        # Save positions
        minpositions = simulation.context.getState(getPositions=True).getPositions()

        # Only "relaxed_complex" outputs go under outdir; intermediate
        # (protein-only) outputs land in the current working directory.
        if "relaxed_complex" in out_title:
            target_path = os.path.join(outdir, f'{out_title}.pdb')
        else:
            target_path = f'{out_title}.pdb'
        # BUGFIX: close the output PDB handle deterministically (was a bare
        # open() whose handle was never closed).
        with open(target_path, 'w') as handle:
            mm_app.PDBFile.writeFile(topology, minpositions, handle)

        # Get and return the minimized energy
        minimized_energy = simulation.context.getState(getEnergy=True).getPotentialEnergy()
    finally:
        # BUGFIX: close the DCD reporter even when minimization raises, so
        # the trajectory file handle is not leaked.
        reporter.close()

    return topology, minpositions, minimized_energy
| 130 |
+
|
| 131 |
+
|
def add_restraints(
        system,
        topology,
        positions,
        restraint_type
):
    '''Add harmonic position restraints to a selected group of atoms.

    Code adapted from https://gist.github.com/peastman/ad8cda653242d731d75e18c836b2a3a5

    Args:
        system: OpenMM System the restraint force is added to (mutated in place).
        topology: Topology whose atoms are scanned for restraint candidates.
        positions: Reference positions the atoms are restrained towards.
        restraint_type: 'protein' restrains every atom whose name lacks an
            'x'; 'CA+ligand' restrains protein CA atoms plus atoms whose
            name contains an 'x'. Matching is case-insensitive.

    Returns:
        The same System with the restraint force added.
    '''
    restraint = CustomExternalForce('k*periodicdistance(x, y, z, x0, y0, z0)^2')
    system.addForce(restraint)
    restraint.addGlobalParameter('k', 100000000.0 * mm_unit.kilojoules_per_mole / mm_unit.nanometer ** 2)
    restraint.addPerParticleParameter('x0')
    restraint.addPerParticleParameter('y0')
    restraint.addPerParticleParameter('z0')

    # BUGFIX: compare case-insensitively. run() passes "ca+ligand" by
    # default, which previously matched neither branch, so no restraint
    # particles were ever added (silently).
    mode = restraint_type.lower()
    # NOTE(review): ligand atoms are assumed to carry a lowercase 'x' in
    # their name — verify against the ligand topology's atom naming.
    for atom in topology.atoms():
        if mode == 'protein':
            if 'x' not in atom.name:
                restraint.addParticle(atom.index, positions[atom.index])
        elif mode == 'ca+ligand':
            if ('x' in atom.name) or (atom.name == "CA"):
                restraint.addParticle(atom.index, positions[atom.index])

    return system
| 159 |
+
|
| 160 |
+
|
def run(
        input_pdb,
        outdir,
        mol_in,
        file_name,
        restraint_type="ca+ligand",
        relax_protein_first=False,
        steps=100,
):
    """Relax a protein-ligand complex with restrained energy minimization.

    Loads the ligand with OpenFF, repairs and protonates the receptor with
    PDBFixer, optionally pre-relaxes the bare protein, assembles the
    complex, parameterizes it (Amber14 + GBn2 implicit solvent for the
    protein, GAFF 2.11 for the ligand), adds position restraints, and
    minimizes. minimize_energy writes the relaxed complex to
    ``<outdir>/<file_name>_relaxed_complex.pdb``.

    Args:
        input_pdb: Path to the receptor PDB file.
        outdir: Directory where relaxed structures are written.
        mol_in: Path to the ligand file (e.g. SDF) readable by OpenFF.
        file_name: Base name used for the output PDB files.
        restraint_type: Restraint group forwarded to add_restraints.
        relax_protein_first: If True, minimize the protein alone before
            building the complex.
        steps: Unused; retained for interface compatibility.

    Raises:
        SystemExit: If the ligand contains radicals (unsupported by OpenFF).
    """
    try:
        ligand_mol = Molecule.from_file(mol_in)
    # Allow undefined stereochemistry to be loaded on a second attempt.
    except UndefinedStereochemistryError:
        print('Undefined Stereochemistry Error found! Trying with undefined stereo flag True')
        ligand_mol = Molecule.from_file(mol_in, allow_undefined_stereo=True)
    # Radicals are unsupported -- break out of the script entirely.
    except RadicalsNotSupportedError:
        print('OpenFF does not currently support radicals -- use unrelaxed structure')
        sys.exit()
    # Assign partial charges explicitly because the default method (am1bcc)
    # does not work here.
    ligand_mol.assign_partial_charges(partial_charge_method='gasteiger')

    # Read the protein PDB, repair it and add hydrogens.
    protein_topology, protein_positions = fix_pdb(input_pdb, outdir, file_name)

    # Build a protein-only system, used for the optional pre-relaxation.
    system = set_system(protein_topology)
    if relax_protein_first:
        print('Relaxing ONLY protein structure...')
        # BUGFIX: minimize_energy returns (topology, positions, energy);
        # the 3-tuple was previously unpacked into two names, raising
        # ValueError whenever relax_protein_first was True.
        protein_topology, protein_positions, _ = minimize_energy(
            protein_topology,
            system,
            protein_positions,
            outdir,
            f'{file_name}_relaxed_protein'
        )

    # Assemble the complex: protein first, then the ligand.
    modeller = Modeller(protein_topology, protein_positions)
    lig_top = ligand_mol.to_topology()
    modeller.add(lig_top.to_openmm(), lig_top.get_positions().to_openmm())

    # GAFF for the ligand, Amber14 with GBn2 implicit solvent for the protein.
    system_generator = SystemGenerator(
        forcefields=['amber14-all.xml', 'implicit/gbn2.xml'],
        small_molecule_forcefield='gaff-2.11',
        molecules=[ligand_mol],
    )
    system = system_generator.create_system(modeller.topology, molecules=ligand_mol)

    system = add_restraints(system, modeller.topology, modeller.positions, restraint_type=restraint_type)

    # Minimize the restrained complex; the structure is written inside
    # minimize_energy.
    _, _, minimized_energy = minimize_energy(
        modeller.topology,
        system,
        modeller.positions,
        outdir,
        f'{file_name}_relaxed_complex'
    )
| 241 |
+
|
| 242 |
+
|
def relax(receptor_pdb, ligand_mol_sdf):
    """Best-effort relaxation of a receptor/ligand pair on disk.

    Derives output names from *receptor_pdb* (whose basename must contain
    "receptor"), runs the restrained minimization under the matching
    "system..." name, then strips HETATM records from the relaxed complex
    and writes a receptor-only ``<file_name>_relaxed_complex.pdb`` next to
    the input. Failures are logged rather than raised.

    Args:
        receptor_pdb: Path to the receptor PDB (basename contains "receptor").
        ligand_mol_sdf: Path to the ligand SDF file.
    """
    output_dir = os.path.split(receptor_pdb)[0]
    file_name = os.path.split(receptor_pdb)[1].split(".")[0]
    # e.g. "receptor_XYZ" -> "system_XYZ": the full complex is written under
    # the system name, the HETATM-stripped receptor under the original name.
    system_file_name = "system" + file_name.split("receptor")[1]
    try:
        run(
            input_pdb=receptor_pdb,
            outdir=output_dir,
            mol_in=ligand_mol_sdf,
            file_name=system_file_name
        )
        lines = load_txt(
            os.path.join(output_dir, f"{system_file_name}_relaxed_complex.pdb")).split("\n")
        receptor = "\n".join([i for i in lines if "HETATM" not in i])
        dump_txt(receptor, os.path.join(output_dir, f"{file_name}_relaxed_complex.pdb"))
    except Exception as e:
        # BUGFIX: previously printed the builtin `dir` function instead of
        # identifying which input failed to relax.
        print(receptor_pdb, "can't relax,", e)
PhysDock/data/tools/PDBData.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2000 Andrew Dalke. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This file is part of the Biopython distribution and governed by your
|
| 4 |
+
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
|
| 5 |
+
# Please see the LICENSE file that should have been included as part of this
|
| 6 |
+
# package.
|
| 7 |
+
"""Information about the IUPAC alphabets."""
|
| 8 |
+
|
# The 20 canonical amino-acid one-letter codes; the extended alphabet
# appends the ambiguity/rare codes B, X, Z, J, U, O (see comments below).
protein_letters = "ACDEFGHIKLMNPQRSTVWY"
extended_protein_letters = protein_letters + "BXZJUO"
| 11 |
+
|
| 12 |
+
# B = "Asx"; aspartic acid or asparagine (D or N)
|
| 13 |
+
# X = "Xxx"; unknown or 'other' amino acid
|
| 14 |
+
# Z = "Glx"; glutamic acid or glutamine (E or Q)
|
| 15 |
+
# http://www.chem.qmul.ac.uk/iupac/AminoAcid/A2021.html#AA212
|
| 16 |
+
#
|
| 17 |
+
# J = "Xle"; leucine or isoleucine (L or I, used in NMR)
|
| 18 |
+
# Mentioned in http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
|
| 19 |
+
# Also the International Nucleotide Sequence Database Collaboration (INSDC)
|
| 20 |
+
# (i.e. GenBank, EMBL, DDBJ) adopted this in 2006
|
| 21 |
+
# http://www.ddbj.nig.ac.jp/insdc/icm2006-e.html
|
| 22 |
+
#
|
| 23 |
+
# Xle (J); Leucine or Isoleucine
|
| 24 |
+
# The residue abbreviations, Xle (the three-letter abbreviation) and J
|
| 25 |
+
# (the one-letter abbreviation) are reserved for the case that cannot
|
| 26 |
+
# experimentally distinguish leucine from isoleucine.
|
| 27 |
+
#
|
| 28 |
+
# U = "Sec"; selenocysteine
|
| 29 |
+
# http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
|
| 30 |
+
#
|
| 31 |
+
# O = "Pyl"; pyrrolysine
|
| 32 |
+
# http://www.chem.qmul.ac.uk/iubmb/newsletter/2009.html#item35
|
| 33 |
+
|
# One-letter -> three-letter (upper-case) codes for the 20 canonical amino
# acids, plus the inverse three-letter -> one-letter mapping derived from it.
protein_letters_1to3 = dict(
    zip(
        "ACDEFGHIKLMNPQRSTVWY",
        (
            "ALA CYS ASP GLU PHE GLY HIS ILE LYS LEU "
            "MET ASN PRO GLN ARG SER THR VAL TRP TYR"
        ).split(),
    )
)
protein_letters_3to1 = {three: one for one, three in protein_letters_1to3.items()}
| 58 |
+
|
# Non-standard / modified amino-acid residues (PDB Chemical Component
# Dictionary three-letter codes) mapped to the one-letter code of their
# closest canonical parent residue.
# NOTE(review): some keys appear space-padded (e.g. "MA "); the padding
# width is taken from the rendered source — verify against the PDB CCD.
protein_letters_3to1_extended = {
    "A5N": "N", "A8E": "V", "A9D": "S", "AA3": "A", "AA4": "A", "AAR": "R",
    "ABA": "A", "ACL": "R", "AEA": "C", "AEI": "D", "AFA": "N", "AGM": "R",
    "AGQ": "Y", "AGT": "C", "AHB": "N", "AHL": "R", "AHO": "A", "AHP": "A",
    "AIB": "A", "AKL": "D", "AKZ": "D", "ALA": "A", "ALC": "A", "ALM": "A",
    "ALN": "A", "ALO": "T", "ALS": "A", "ALT": "A", "ALV": "A", "ALY": "K",
    "AME": "M", "AN6": "L", "AN8": "A", "API": "K", "APK": "K", "AR2": "R",
    "AR4": "E", "AR7": "R", "ARG": "R", "ARM": "R", "ARO": "R", "AS7": "N",
    "ASA": "D", "ASB": "D", "ASI": "D", "ASK": "D", "ASL": "D", "ASN": "N",
    "ASP": "D", "ASQ": "D", "AYA": "A", "AZH": "A", "AZK": "K", "AZS": "S",
    "AZY": "Y", "AVJ": "H", "A30": "Y", "A3U": "F", "ECC": "Q", "ECX": "C",
    "EFC": "C", "EHP": "F", "ELY": "K", "EME": "E", "EPM": "M", "EPQ": "Q",
    "ESB": "Y", "ESC": "M", "EXY": "L", "EXA": "K", "E0Y": "P", "E9V": "H",
    "E9M": "W", "EJA": "C", "EUP": "T", "EZY": "G", "E9C": "Y", "EW6": "S",
    "EXL": "W", "I2M": "I", "I4G": "G", "I58": "K", "IAM": "A", "IAR": "R",
    "ICY": "C", "IEL": "K", "IGL": "G", "IIL": "I", "ILE": "I", "ILG": "E",
    "ILM": "I", "ILX": "I", "ILY": "K", "IML": "I", "IOR": "R", "IPG": "G",
    "IT1": "K", "IYR": "Y", "IZO": "M", "IC0": "G", "M0H": "C", "M2L": "K",
    "M2S": "M", "M30": "G", "M3L": "K", "M3R": "K", "MA ": "A", "MAA": "A",
    "MAI": "R", "MBQ": "Y", "MC1": "S", "MCL": "K", "MCS": "C", "MD3": "C",
    "MD5": "C", "MD6": "G", "MDF": "Y", "ME0": "M", "MEA": "F", "MEG": "E",
    "MEN": "N", "MEQ": "Q", "MET": "M", "MEU": "G", "MFN": "E", "MGG": "R",
    "MGN": "Q", "MGY": "G", "MH1": "H", "MH6": "S", "MHL": "L", "MHO": "M",
    "MHS": "H", "MHU": "F", "MIR": "S", "MIS": "S", "MK8": "L", "ML3": "K",
    "MLE": "L", "MLL": "L", "MLY": "K", "MLZ": "K", "MME": "M", "MMO": "R",
    "MNL": "L", "MNV": "V", "MP8": "P", "MPQ": "G", "MSA": "G", "MSE": "M",
    "MSL": "M", "MSO": "M", "MT2": "M", "MTY": "Y", "MVA": "V", "MYK": "K",
    "MYN": "R", "QCS": "C", "QIL": "I", "QMM": "Q", "QPA": "C", "QPH": "F",
    "Q3P": "K", "QVA": "C", "QX7": "A", "Q2E": "W", "Q75": "M", "Q78": "F",
    "QM8": "L", "QMB": "A", "QNQ": "C", "QNT": "C", "QNW": "C", "QO2": "C",
    "QO5": "C", "QO8": "C", "QQ8": "Q", "U2X": "Y", "U3X": "F", "UF0": "S",
    "UGY": "G", "UM1": "A", "UM2": "A", "UMA": "A", "UQK": "A", "UX8": "W",
    "UXQ": "F", "YCM": "C", "YOF": "Y", "YPR": "P", "YPZ": "Y", "YTH": "T",
    "Y1V": "L", "Y57": "K", "YHA": "K", "200": "F", "23F": "F", "23P": "A",
    "26B": "T", "28X": "T", "2AG": "A", "2CO": "C", "2FM": "M", "2GX": "F",
    "2HF": "H", "2JG": "S", "2KK": "K", "2KP": "K", "2LT": "Y", "2LU": "L",
    "2ML": "L", "2MR": "R", "2MT": "P", "2OR": "R", "2P0": "P", "2QZ": "T",
    "2R3": "Y", "2RA": "A", "2RX": "S", "2SO": "H", "2TY": "Y", "2VA": "V",
    "2XA": "C", "2ZC": "S", "6CL": "K", "6CW": "W", "6GL": "A", "6HN": "K",
    "60F": "C", "66D": "I", "6CV": "A", "6M6": "C", "6V1": "C", "6WK": "C",
    "6Y9": "P", "6DN": "K", "DA2": "R", "DAB": "A", "DAH": "F", "DBS": "S",
    "DBU": "T", "DBY": "Y", "DBZ": "A", "DC2": "C", "DDE": "H", "DDZ": "A",
    "DI7": "Y", "DHA": "S", "DHN": "V", "DIR": "R", "DLS": "K", "DM0": "K",
    "DMH": "N", "DMK": "D", "DNL": "K", "DNP": "A", "DNS": "K", "DNW": "A",
    "DOH": "D", "DON": "L", "DP1": "R", "DPL": "P", "DPP": "A", "DPQ": "Y",
    "DYS": "C", "D2T": "D", "DYA": "D", "DJD": "F", "DYJ": "P", "DV9": "E",
    "H14": "F", "H1D": "M", "H5M": "P", "HAC": "A", "HAR": "R", "HBN": "H",
    "HCM": "C", "HGY": "G", "HHI": "H", "HIA": "H", "HIC": "H", "HIP": "H",
    "HIQ": "H", "HIS": "H", "HL2": "L", "HLU": "L", "HMR": "R", "HNC": "C",
    "HOX": "F", "HPC": "F", "HPE": "F", "HPH": "F", "HPQ": "F", "HQA": "A",
    "HR7": "R", "HRG": "R", "HRP": "W", "HS8": "H", "HS9": "H", "HSE": "S",
    "HSK": "H", "HSL": "S", "HSO": "H", "HT7": "W", "HTI": "C", "HTR": "W",
    "HV5": "A", "HVA": "V", "HY3": "P", "HYI": "M", "HYP": "P", "HZP": "P",
    "HIX": "A", "HSV": "H", "HLY": "K", "HOO": "H", "H7V": "A", "L5P": "K",
    "LRK": "K", "L3O": "L", "LA2": "K", "LAA": "D", "LAL": "A", "LBY": "K",
    "LCK": "K", "LCX": "K", "LDH": "K", "LE1": "V", "LED": "L", "LEF": "L",
    "LEH": "L", "LEM": "L", "LEN": "L", "LET": "K", "LEU": "L", "LEX": "L",
    "LGY": "K", "LLO": "K", "LLP": "K", "LLY": "K", "LLZ": "K", "LME": "E",
    "LMF": "K", "LMQ": "Q", "LNE": "L", "LNM": "L", "LP6": "K", "LPD": "P",
    "LPG": "G", "LPS": "S", "LSO": "K", "LTR": "W", "LVG": "G", "LVN": "V",
    "LWY": "P", "LYF": "K", "LYK": "K", "LYM": "K", "LYN": "K", "LYO": "K",
    "LYP": "K", "LYR": "K", "LYS": "K", "LYU": "K", "LYX": "K", "LYZ": "K",
    "LAY": "L", "LWI": "F", "LBZ": "K", "P1L": "C", "P2Q": "Y", "P2Y": "P",
    "P3Q": "Y", "PAQ": "Y", "PAS": "D", "PAT": "W", "PBB": "C", "PBF": "F",
    "PCA": "Q", "PCC": "P", "PCS": "F", "PE1": "K", "PEC": "C", "PF5": "F",
    "PFF": "F", "PG1": "S", "PGY": "G", "PHA": "F", "PHD": "D", "PHE": "F",
    "PHI": "F", "PHL": "F", "PHM": "F", "PKR": "P", "PLJ": "P", "PM3": "F",
    "POM": "P", "PPN": "F", "PR3": "C", "PR4": "P", "PR7": "P", "PR9": "P",
    "PRJ": "P", "PRK": "K", "PRO": "P", "PRS": "P", "PRV": "G", "PSA": "F",
    "PSH": "H", "PTH": "Y", "PTM": "Y", "PTR": "Y", "PVH": "H", "PXU": "P",
    "PYA": "A", "PYH": "K", "PYX": "C", "PH6": "P", "P9S": "C", "P5U": "S",
    "POK": "R", "T0I": "Y", "T11": "F", "TAV": "D", "TBG": "V", "TBM": "T",
    "TCQ": "Y", "TCR": "W", "TEF": "F", "TFQ": "F", "TH5": "T", "TH6": "T",
    "THC": "T", "THR": "T", "THZ": "R", "TIH": "A", "TIS": "S", "TLY": "K",
    "TMB": "T", "TMD": "T", "TNB": "C", "TNR": "S", "TNY": "T", "TOQ": "W",
    "TOX": "W", "TPJ": "P", "TPK": "P", "TPL": "W", "TPO": "T", "TPQ": "Y",
    "TQI": "W", "TQQ": "W", "TQZ": "C", "TRF": "W", "TRG": "K", "TRN": "W",
    "TRO": "W", "TRP": "W", "TRQ": "W", "TRW": "W", "TRX": "W", "TRY": "W",
    "TS9": "I", "TSY": "C", "TTQ": "W", "TTS": "Y", "TXY": "Y", "TY1": "Y",
    "TY2": "Y", "TY3": "Y", "TY5": "Y", "TY8": "Y", "TY9": "Y", "TYB": "Y",
    "TYC": "Y", "TYE": "Y", "TYI": "Y", "TYJ": "Y", "TYN": "Y", "TYO": "Y",
    "TYQ": "Y", "TYR": "Y", "TYS": "Y", "TYT": "Y", "TYW": "Y", "TYY": "Y",
    "T8L": "T", "T9E": "T", "TNQ": "W", "TSQ": "F", "TGH": "W", "X2W": "E",
    "XCN": "C", "XPR": "P", "XSN": "N", "XW1": "A", "XX1": "K", "XYC": "A",
    "XA6": "F", "11Q": "P", "11W": "E", "12L": "P", "12X": "P", "12Y": "P",
    "143": "C", "1AC": "A", "1L1": "A", "1OP": "Y", "1PA": "F", "1PI": "A",
    "1TQ": "W", "1TY": "Y", "1X6": "S", "56A": "H", "5AB": "A", "5CS": "C",
    "5CW": "W", "5HP": "E", "5OH": "A", "5PG": "G", "51T": "Y", "54C": "W",
    "5CR": "F", "5CT": "K", "5FQ": "A", "5GM": "I", "5JP": "S", "5T3": "K",
    "5MW": "K", "5OW": "K", "5R5": "S", "5VV": "N", "5XU": "A", "55I": "F",
    "999": "D", "9DN": "N", "9NE": "E", "9NF": "F", "9NR": "R", "9NV": "V",
    "9E7": "K", "9KP": "K", "9WV": "A", "9TR": "K", "9TU": "K", "9TX": "K",
    "9U0": "K", "9IJ": "F", "B1F": "F", "B27": "T", "B2A": "A", "B2F": "F",
    "B2I": "I", "B2V": "V", "B3A": "A", "B3D": "D", "B3E": "E", "B3K": "K",
    "B3U": "H", "B3X": "N", "B3Y": "Y", "BB6": "C", "BB7": "C", "BB8": "F",
    "BB9": "C", "BBC": "C", "BCS": "C", "BCX": "C", "BFD": "D", "BG1": "S",
    "BH2": "D", "BHD": "D", "BIF": "F", "BIU": "I", "BL2": "L", "BLE": "L",
    "BLY": "K", "BMT": "T", "BNN": "F", "BOR": "R", "BP5": "A", "BPE": "C",
    "BSE": "S", "BTA": "L", "BTC": "C", "BTK": "K", "BTR": "W", "BUC": "C",
    "BUG": "V", "BYR": "Y", "BWV": "R", "BWB": "S", "BXT": "S", "F2F": "F",
    "F2Y": "Y", "FAK": "K", "FB5": "A", "FB6": "A", "FC0": "F", "FCL": "F",
    "FDL": "K", "FFM": "C", "FGL": "G", "FGP": "S", "FH7": "K", "FHL": "K",
    "FHO": "K", "FIO": "R", "FLA": "A", "FLE": "L", "FLT": "Y", "FME": "M",
    "FOE": "C", "FP9": "P", "FPK": "P", "FT6": "W", "FTR": "W", "FTY": "Y",
    "FVA": "V", "FZN": "K", "FY3": "Y", "F7W": "W", "FY2": "Y", "FQA": "K",
    "F7Q": "Y", "FF9": "K", "FL6": "D", "JJJ": "C", "JJK": "C", "JJL": "C",
    "JLP": "K", "J3D": "C", "J9Y": "R", "J8W": "S", "JKH": "P", "N10": "S",
    "N7P": "P", "NA8": "A", "NAL": "A", "NAM": "A", "NBQ": "Y", "NC1": "S",
    "NCB": "A", "NEM": "H", "NEP": "H", "NFA": "F", "NIY": "Y", "NLB": "L",
    "NLE": "L", "NLN": "L", "NLO": "L", "NLP": "L", "NLQ": "Q", "NLY": "G",
    "NMC": "G", "NMM": "R", "NNH": "R", "NOT": "L", "NPH": "C", "NPI": "A",
    "NTR": "Y", "NTY": "Y", "NVA": "V", "NWD": "A", "NYB": "C", "NYS": "C",
    "NZH": "H", "N80": "P", "NZC": "T", "NLW": "L", "N0A": "F", "N9P": "A",
    "N65": "K", "R1A": "C", "R4K": "W", "RE0": "W", "RE3": "W", "RGL": "R",
    "RGP": "E", "RT0": "P", "RVX": "S", "RZ4": "S", "RPI": "R", "RVJ": "A",
    "VAD": "V", "VAF": "V", "VAH": "V", "VAI": "V", "VAL": "V", "VB1": "K",
    "VH0": "P", "VR0": "R", "V44": "C", "V61": "F", "VPV": "K", "V5N": "H",
    "V7T": "K", "Z01": "A", "Z3E": "T", "Z70": "H", "ZBZ": "C", "ZCL": "F",
    "ZU0": "T", "ZYJ": "P", "ZYK": "P", "ZZD": "C", "ZZJ": "A", "ZIQ": "W",
    "ZPO": "P", "ZDJ": "Y", "ZT1": "K", "30V": "C", "31Q": "C", "33S": "F",
    "33W": "A", "34E": "V", "3AH": "H", "3BY": "P", "3CF": "F", "3CT": "Y",
    "3GA": "A", "3GL": "E", "3MD": "D", "3MY": "Y", "3NF": "Y", "3O3": "E",
    "3PX": "P", "3QN": "K", "3TT": "P", "3XH": "G", "3YM": "Y", "3WS": "A",
    "3WX": "P", "3X9": "C", "3ZH": "H", "7JA": "I", "73C": "S", "73N": "R",
    "73O": "Y", "73P": "K", "74P": "K", "7N8": "F", "7O5": "A", "7XC": "F",
    "7ID": "D", "7OZ": "A", "C1S": "C", "C1T": "C", "C1X": "K", "C22": "A",
    "C3Y": "C", "C4R": "C", "C5C": "C", "C6C": "C", "CAF": "C", "CAS": "C",
    "CAY": "C", "CCS": "C", "CEA": "C", "CGA": "E", "CGU": "E", "CGV": "C",
    "CHP": "G", "CIR": "R", "CLE": "L", "CLG": "K", "CLH": "K", "CME": "C",
    "CMH": "C", "CML": "C", "CMT": "C", "CR5": "G", "CS0": "C", "CS1": "C",
    "CS3": "C", "CS4": "C", "CSA": "C", "CSB": "C", "CSD": "C", "CSE": "C",
    "CSJ": "C", "CSO": "C", "CSP": "C", "CSR": "C", "CSS": "C", "CSU": "C",
    "CSW": "C", "CSX": "C", "CSZ": "C", "CTE": "W", "CTH": "T", "CWD": "A",
    "CWR": "S", "CXM": "M", "CY0": "C", "CY1": "C", "CY3": "C", "CY4": "C",
    "CYA": "C", "CYD": "C", "CYF": "C", "CYG": "C", "CYJ": "K", "CYM": "C",
    "CYQ": "C", "CYR": "C", "CYS": "C", "CYW": "C", "CZ2": "C", "CZZ": "C",
    "CG6": "C", "C1J": "R", "C4G": "R", "C67": "R", "C6D": "R", "CE7": "N",
    "CZS": "A", "G01": "E", "G8M": "E", "GAU": "E", "GEE": "G", "GFT": "S",
    "GHC": "E", "GHG": "Q", "GHW": "E", "GL3": "G", "GLH": "Q", "GLJ": "E",
    "GLK": "E", "GLN": "Q", "GLQ": "E", "GLU": "E", "GLY": "G", "GLZ": "G",
    "GMA": "E", "GME": "E", "GNC": "Q", "GPL": "K", "GSC": "G", "GSU": "E",
    "GT9": "C", "GVL": "S", "G3M": "R", "G5G": "L", "G1X": "Y", "G8X": "P",
    "K1R": "C", "KBE": "K", "KCX": "K", "KFP": "K", "KGC": "K", "KNB": "A",
    "KOR": "M", "KPI": "K", "KPY": "K", "KST": "K", "KYN": "W", "KYQ": "K",
    "KCR": "K", "KPF": "K", "K5L": "S", "KEO": "K", "KHB": "K", "KKD": "D",
    "K5H": "C", "K7K": "S", "OAR": "R", "OAS": "S", "OBS": "K", "OCS": "C",
    "OCY": "C", "OHI": "H", "OHS": "D", "OLD": "H", "OLT": "T", "OLZ": "S",
    "OMH": "S", "OMT": "M", "OMX": "Y", "OMY": "Y", "ONH": "A", "ORN": "A",
    "ORQ": "R", "OSE": "S", "OTH": "T", "OXX": "D", "OYL": "H", "O7A": "T",
    "O7D": "W", "O7G": "V", "O2E": "S", "O6H": "W", "OZW": "F", "S12": "S",
    "S1H": "S", "S2C": "C", "S2P": "A", "SAC": "S", "SAH": "C", "SAR": "G",
    "SBG": "S", "SBL": "S", "SCH": "C", "SCS": "C", "SCY": "C", "SD4": "N",
    "SDB": "S", "SDP": "S", "SEB": "S", "SEE": "S", "SEG": "A", "SEL": "S",
    "SEM": "S", "SEN": "S", "SEP": "S", "SER": "S", "SET": "S", "SGB": "S",
    "SHC": "C", "SHP": "G", "SHR": "K", "SIB": "C", "SLL": "K", "SLZ": "K",
    "SMC": "C", "SME": "M", "SMF": "F", "SNC": "C", "SNN": "N", "SOY": "S",
    "SRZ": "S", "STY": "Y", "SUN": "S", "SVA": "S", "SVV": "S", "SVW": "S",
    "SVX": "S", "SVY": "S", "SVZ": "S", "SXE": "S", "SKH": "K", "SNM": "S",
    "SNK": "H", "SWW": "S", "WFP": "F", "WLU": "L", "WPA": "F", "WRP": "W",
    "WVL": "V", "02K": "A", "02L": "N", "02O": "A", "02Y": "A", "033": "V",
    "037": "P", "03Y": "C", "04U": "P", "04V": "P", "05N": "P", "07O": "C",
    "0A0": "D", "0A1": "Y", "0A2": "K", "0A8": "C", "0A9": "F", "0AA": "V",
    "0AB": "V", "0AC": "G", "0AF": "W", "0AG": "L", "0AH": "S", "0AK": "D",
    "0AR": "R", "0BN": "F", "0CS": "A", "0E5": "T", "0EA": "Y", "0FL": "A",
    "0LF": "P", "0NC": "A", "0PR": "Y", "0QL": "C", "0TD": "D", "0UO": "W",
    "0WZ": "Y", "0X9": "R", "0Y8": "P", "4AF": "F", "4AR": "R", "4AW": "W",
    "4BF": "F", "4CF": "F", "4CY": "M", "4DP": "W", "4FB": "P", "4FW": "W",
    "4HL": "Y", "4HT": "W", "4IN": "W", "4MM": "M", "4PH": "F", "4U7": "A",
    "41H": "F", "41Q": "N", "42Y": "S", "432": "S", "45F": "P", "4AK": "K",
    "4D4": "R", "4GJ": "C", "4KY": "P", "4L0": "P", "4LZ": "Y", "4N7": "P",
    "4N8": "P", "4N9": "P", "4OG": "W", "4OU": "F", "4OV": "S", "4OZ": "S",
    "4PQ": "W", "4SJ": "F", "4WQ": "A", "4HH": "S", "4HJ": "S", "4J4": "C",
    "4J5": "R", "4II": "F", "4VI": "R", "823": "N", "8SP": "S", "8AY": "A",
}
| 233 |
+
|
# Nucleic Acids
# Canonical nucleotide residue codes (space-padded three-character fields,
# as they appear in PDB/mmCIF residue columns) mapped to one-letter codes.
# NOTE(review): the padding width of the keys is taken from the rendered
# source and may have been collapsed — confirm against the lookup sites.
nucleic_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}

# RNA subset of the canonical nucleotide codes.
rna_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
}

# DNA subset of the canonical nucleotide codes.
dna_letters_3to1 = {
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}
|
| 247 |
+
|
| 248 |
+
# fmt: off
# Extended nucleic-acid mapping: PDB chemical-component IDs — including modified
# and non-standard nucleotides — mapped to the one-letter code of the parent base.
# NOTE(review): single-letter keys (e.g. "A ", "U ") appear space-padded; confirm
# the exact padding width survived extraction from the original file.
nucleic_letters_3to1_extended = {
    "A ": "A", "A23": "A", "A2L": "A", "A2M": "A", "A34": "A", "A35": "A",
    "A38": "A", "A39": "A", "A3A": "A", "A3P": "A", "A40": "A", "A43": "A",
    "A44": "A", "A47": "A", "A5L": "A", "A5M": "C", "A5O": "A", "A6A": "A",
    "A6C": "C", "A6G": "G", "A6U": "U", "A7E": "A", "A9Z": "A", "ABR": "A",
    "ABS": "A", "AD2": "A", "ADI": "A", "ADP": "A", "AET": "A", "AF2": "A",
    "AFG": "G", "AMD": "A", "AMO": "A", "AP7": "A", "AS ": "A", "ATD": "T",
    "ATL": "T", "ATM": "T", "AVC": "A", "AI5": "C", "E ": "A", "E1X": "A",
    "EDA": "A", "EFG": "G", "EHG": "G", "EIT": "T", "EXC": "C", "E3C": "C",
    "E6G": "G", "E7G": "G", "EQ4": "G", "EAN": "T", "I5C": "C", "IC ": "C",
    "IG ": "G", "IGU": "G", "IMC": "C", "IMP": "G", "IU ": "U", "I4U": "U",
    "IOO": "G", "M1G": "G", "M2G": "G", "M4C": "C", "M5M": "C", "MA6": "A",
    "MA7": "A", "MAD": "A", "MCY": "C", "ME6": "C", "MEP": "U", "MG1": "G",
    "MGQ": "A", "MGT": "G", "MGV": "G", "MIA": "A", "MMT": "T", "MNU": "U",
    "MRG": "G", "MTR": "T", "MTU": "A", "MFO": "G", "M7A": "A", "MHG": "G",
    "MMX": "C", "QUO": "G", "QCK": "T", "QSQ": "A", "U ": "U", "U25": "U",
    "U2L": "U", "U2P": "U", "U31": "U", "U34": "U", "U36": "U", "U37": "U",
    "U8U": "U", "UAR": "U", "UBB": "U", "UBD": "U", "UD5": "U", "UPV": "U",
    "UR3": "U", "URD": "U", "US3": "T", "US5": "U", "UZR": "U", "UMO": "U",
    "U23": "U", "U48": "C", "U7B": "C", "Y ": "A", "YCO": "C", "YG ": "G",
    "YYG": "G", "23G": "G", "26A": "A", "2AR": "A", "2AT": "T", "2AU": "U",
    "2BT": "T", "2BU": "A", "2DA": "A", "2DT": "T", "2EG": "G", "2GT": "T",
    "2JV": "G", "2MA": "A", "2MG": "G", "2MU": "U", "2NT": "T", "2OM": "U",
    "2OT": "T", "2PR": "G", "2SG": "G", "2ST": "T", "63G": "G", "63H": "G",
    "64T": "T", "68Z": "G", "6CT": "T", "6HA": "A", "6HB": "A", "6HC": "C",
    "6HG": "G", "6HT": "T", "6IA": "A", "6MA": "A", "6MC": "A", "6MP": "A",
    "6MT": "A", "6MZ": "A", "6OG": "G", "6PO": "G", "6FK": "G", "6NW": "A",
    "6OO": "C", "D00": "C", "D3T": "T", "D4M": "T", "DA ": "A", "DC ": "C",
    "DCG": "G", "DCT": "C", "DDG": "G", "DFC": "C", "DFG": "G", "DG ": "G",
    "DG8": "G", "DGI": "G", "DGP": "G", "DHU": "U", "DNR": "C", "DOC": "C",
    "DPB": "T", "DRT": "T", "DT ": "T", "DZM": "A", "D4B": "C", "H2U": "U",
    "HN0": "G", "HN1": "G", "LC ": "C", "LCA": "A", "LCG": "G", "LG ": "G",
    "LGP": "G", "LHU": "U", "LSH": "T", "LST": "T", "LDG": "G", "L3X": "A",
    "LHH": "C", "LV2": "C", "L1J": "G", "P ": "G", "P2T": "T", "P5P": "A",
    "PG7": "G", "PGN": "G", "PGP": "G", "PMT": "C", "PPU": "A", "PPW": "G",
    "PR5": "A", "PRN": "A", "PST": "T", "PSU": "U", "PU ": "A", "PVX": "C",
    "PYO": "U", "PZG": "G", "P4U": "U", "P7G": "G", "T ": "T", "T2S": "T",
    "T31": "U", "T32": "T", "T36": "T", "T37": "T", "T38": "T", "T39": "T",
    "T3P": "T", "T41": "T", "T48": "T", "T49": "T", "T4S": "T", "T5S": "T",
    "T64": "T", "T6A": "A", "TA3": "T", "TAF": "T", "TBN": "A", "TC1": "C",
    "TCP": "T", "TCY": "A", "TDY": "T", "TED": "T", "TFE": "T", "TFF": "T",
    "TFO": "A", "TFT": "T", "TGP": "G", "TCJ": "C", "TLC": "T", "TP1": "T",
    "TPC": "C", "TPG": "G", "TSP": "T", "TTD": "T", "TTM": "T", "TXD": "A",
    "TXP": "A", "TC ": "C", "TG ": "G", "T0N": "G", "T0Q": "G", "X ": "G",
    "XAD": "A", "XAL": "A", "XCL": "C", "XCR": "C", "XCT": "C", "XCY": "C",
    "XGL": "G", "XGR": "G", "XGU": "G", "XPB": "G", "XTF": "T", "XTH": "T",
    "XTL": "T", "XTR": "T", "XTS": "G", "XUA": "A", "XUG": "G", "102": "G",
    "10C": "C", "125": "U", "126": "U", "127": "U", "12A": "A", "16B": "C",
    "18M": "G", "1AP": "A", "1CC": "C", "1FC": "C", "1MA": "A", "1MG": "G",
    "1RN": "U", "1SC": "C", "5AA": "A", "5AT": "T", "5BU": "U", "5CG": "G",
    "5CM": "C", "5FA": "A", "5FC": "C", "5FU": "U", "5HC": "C", "5HM": "C",
    "5HT": "T", "5IC": "C", "5IT": "T", "5MC": "C", "5MU": "U", "5NC": "C",
    "5PC": "C", "5PY": "T", "9QV": "U", "94O": "T", "9SI": "A", "9SY": "A",
    "B7C": "C", "BGM": "G", "BOE": "T", "B8H": "U", "B8K": "G", "B8Q": "C",
    "B8T": "C", "B8W": "G", "B9B": "G", "B9H": "C", "BGH": "G", "F3H": "T",
    "F3N": "A", "F4H": "T", "FA2": "A", "FDG": "G", "FHU": "U", "FMG": "G",
    "FNU": "U", "FOX": "G", "F2T": "U", "F74": "G", "F4Q": "G", "F7H": "C",
    "F7K": "G", "JDT": "T", "JMH": "C", "J0X": "C", "N5M": "C", "N6G": "G",
    "N79": "A", "NCU": "C", "NMS": "T", "NMT": "T", "NTT": "T", "N7X": "C",
    "R ": "A", "RBD": "A", "RDG": "G", "RIA": "A", "RMP": "A", "RPC": "C",
    "RSP": "C", "RSQ": "C", "RT ": "T", "RUS": "U", "RFJ": "G", "V3L": "A",
    "VC7": "G", "Z ": "C", "ZAD": "A", "ZBC": "C", "ZBU": "U", "ZCY": "C",
    "ZGU": "G", "31H": "A", "31M": "A", "3AU": "U", "3DA": "A", "3ME": "U",
    "3MU": "U", "3TD": "U", "70U": "U", "7AT": "A", "7DA": "A", "7GU": "G",
    "7MG": "G", "7BG": "G", "73W": "C", "75B": "U", "7OK": "C", "7S3": "G",
    "7SN": "G", "C ": "C", "C25": "C", "C2L": "C", "C2S": "C", "C31": "C",
    "C32": "C", "C34": "C", "C36": "C", "C37": "C", "C38": "C", "C42": "C",
    "C43": "C", "C45": "C", "C46": "C", "C49": "C", "C4S": "C", "C5L": "C",
    "C6G": "G", "CAR": "C", "CB2": "C", "CBR": "C", "CBV": "C", "CCC": "C",
    "CDW": "C", "CFL": "C", "CFZ": "C", "CG1": "G", "CH ": "C", "CMR": "C",
    "CNU": "U", "CP1": "C", "CSF": "C", "CSL": "C", "CTG": "T", "CX2": "C",
    "C7S": "C", "C7R": "C", "G ": "G", "G1G": "G", "G25": "G", "G2L": "G",
    "G2S": "G", "G31": "G", "G32": "G", "G33": "G", "G36": "G", "G38": "G",
    "G42": "G", "G46": "G", "G47": "G", "G48": "G", "G49": "G", "G7M": "G",
    "GAO": "G", "GCK": "C", "GDO": "G", "GDP": "G", "GDR": "G", "GF2": "G",
    "GFL": "G", "GH3": "G", "GMS": "G", "GN7": "G", "GNG": "G", "GOM": "G",
    "GRB": "G", "GS ": "G", "GSR": "G", "GSS": "G", "GTP": "G", "GX1": "G",
    "KAG": "G", "KAK": "G", "O2G": "G", "OGX": "G", "OMC": "C", "OMG": "G",
    "OMU": "U", "ONE": "U", "O2Z": "A", "OKN": "C", "OKQ": "C", "S2M": "T",
    "S4A": "A", "S4C": "C", "S4G": "G", "S4U": "U", "S6G": "G", "SC ": "C",
    "SDE": "A", "SDG": "G", "SDH": "G", "SMP": "A", "SMT": "T", "SPT": "T",
    "SRA": "A", "SSU": "U", "SUR": "U", "00A": "A", "0AD": "G", "0AM": "A",
    "0AP": "C", "0AV": "A", "0R8": "C", "0SP": "A", "0UH": "G", "47C": "C",
    "4OC": "C", "4PC": "C", "4PD": "C", "4PE": "C", "4SC": "C", "4SU": "U",
    "45A": "A", "4U3": "C", "8AG": "G", "8AN": "A", "8BA": "A", "8FG": "G",
    "8MG": "G", "8OG": "G", "8PY": "G", "8AA": "G", "85Y": "U", "8OS": "G",
    "UNK": "X",  # DEBUG
}
|
| 337 |
+
|
| 338 |
+
# Aliases distinguishing the standard residue tables from the extended
# (modified-residue) tables defined above.
standard_protein_letters_3to1 = protein_letters_3to1
standard_protein_letters_1to3 = protein_letters_1to3
# Extended protein entries that are not among the standard residues.
nonstandard_protein_letters_3to1 = {
    three_letter: one_letter
    for three_letter, one_letter in protein_letters_3to1_extended.items()
    if three_letter not in standard_protein_letters_3to1
}

standard_nucleic_letters_3to1 = nucleic_letters_3to1
# Inverse lookup. Where RNA and DNA entries share a one-letter code, the entry
# appearing later in nucleic_letters_3to1 wins (standard dict-comprehension
# semantics, identical to the original construction).
standard_nucleic_letters_1to3 = {
    one_letter: three_letter
    for three_letter, one_letter in standard_nucleic_letters_3to1.items()
}
# Extended nucleic entries that are not among the standard nucleotides.
nonstandard_nucleic_letters_3to1 = {
    three_letter: one_letter
    for three_letter, one_letter in nucleic_letters_3to1_extended.items()
    if three_letter not in standard_nucleic_letters_3to1
}

# Combined protein + nucleic extended table; nucleic entries take precedence
# on any key collision (right-hand operand of the merge wins).
letters_3to1_extended = {**protein_letters_3to1_extended, **nucleic_letters_3to1_extended}
|
PhysDock/data/tools/__init__.py
ADDED
|
File without changes
|
PhysDock/data/tools/alignment_runner.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os.path
|
| 3 |
+
from functools import partial
|
| 4 |
+
import tqdm
|
| 5 |
+
from typing import Optional, Mapping, Any, Union
|
| 6 |
+
|
| 7 |
+
from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
|
| 8 |
+
from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
|
| 9 |
+
from PhysDock.data.tools.parsers import parse_fasta
|
| 10 |
+
|
| 11 |
+
# Type alias for the supported template-search backends; currently only HHSearch.
# (Union with a single member is equivalent to the member type itself.)
TemplateSearcher = Union[hhsearch.HHSearch]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AlignmentRunner:
    """Runs external homology-search and template-search tools for one query.

    Each tool (jackhmmer / hhblits / nhmmer / hmmalign) is wrapped in a runner
    closure during ``__init__``. A runner is created only when both the tool
    binary and its database(s) are present on disk, so any subset of the
    databases may be configured; absent ones are simply skipped in :meth:`run`.
    """

    def __init__(
            self,
            # Homology search tools
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            nhmmer_binary_path: Optional[str] = None,
            hmmbuild_binary_path: Optional[str] = None,
            hmmalign_binary_path: Optional[str] = None,
            # NOTE(review): kalign_binary_path and hhsearch_binary_path are
            # accepted but not referenced anywhere in this class — confirm
            # whether they are reserved for future use.
            kalign_binary_path: Optional[str] = None,

            # Template search tools
            hhsearch_binary_path: Optional[str] = None,
            template_searcher: Optional[TemplateSearcher] = None,
            template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,

            # Databases
            uniref90_database_path: Optional[str] = None,
            uniprot_database_path: Optional[str] = None,
            uniclust30_database_path: Optional[str] = None,
            uniref30_database_path: Optional[str] = None,
            bfd_database_path: Optional[str] = None,
            reduced_bfd_database_path: Optional[str] = None,
            mgnify_database_path: Optional[str] = None,
            rfam_database_path: Optional[str] = None,
            rnacentral_database_path: Optional[str] = None,
            nt_database_path: Optional[str] = None,
            # Parallelism
            no_cpus: int = 8,
            # Limits: *_seq_limit caps sequences fed to the tool,
            # *_max_hits caps sequences kept from its output.
            uniref90_seq_limit: int = 100000,
            uniprot_seq_limit: int = 500000,
            reduced_bfd_seq_limit: int = 50000,
            mgnify_seq_limit: int = 50000,
            uniref90_max_hits: int = 10000,
            uniprot_max_hits: int = 50000,
            reduced_bfd_max_hits: int = 5000,
            mgnify_max_hits: int = 5000,
            rfam_max_hits: int = 10000,
            rnacentral_max_hits: int = 10000,
            nt_max_hits: int = 10000,
    ):
        # All runners default to None; each is replaced below only when its
        # binary and database(s) exist.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.reduced_bfd_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None
        self.rfam_nhmmer_runner = None
        self.rnacentral_nhmmer_runner = None
        self.nt_nhmmer_runner = None
        self.rna_realign_runner = None
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        def _all_exists(*objs, hhblits_mode=False):
            """Return True iff every argument is non-None and exists on disk.

            In hhblits mode only the parent directory is checked, because
            hhblits database arguments are path prefixes rather than files.
            """
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
                fasta_path: str,
                msa_out_path: str,
                msa_runner,
                msa_format: str,
                max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool and writes its output to msa_out_path."""
            if (msa_format == "sto" and max_sto_sequences is not None):
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # The output filename extension must match the produced format.
            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        def _run_rna_realign_tool(
                fasta_path: str,
                msa_in_path: str,
                msa_out_path: str,
                use_precompute=True,
        ):
            """Realigns an RNA .sto MSA against the query FASTA via hmmalign."""
            runner = hmmalign.Hmmalign(
                hmmbuild_binary_path=hmmbuild_binary_path,
                hmmalign_binary_path=hmmalign_binary_path,
            )
            # An empty input MSA produces an empty output file.
            if os.path.exists(msa_in_path) and os.path.getsize(msa_in_path) == 0:
                with open(msa_out_path, "w") as f:
                    pass
                return
            if use_precompute:
                if os.path.exists(msa_in_path) and os.path.exists(msa_out_path):
                    # Precomputed output exists but looks truncated: redo it.
                    if os.path.getsize(msa_in_path) > 0 and os.path.getsize(msa_out_path) == 0:
                        logging.warning(f"The msa realign file size is zero but the origin file size is over 0! "
                                        f"fasta: {fasta_path} msa_in_file: {msa_in_path}")
                        runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
                else:
                    runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
            else:
                runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)

        # uniclust30 and uniref30 are alternative releases of the same resource.
        assert uniclust30_database_path is None or uniref30_database_path is None, "Only one used"

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )
        if _all_exists(jackhmmer_binary_path, reduced_bfd_database_path):
            self.reduced_bfd_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=reduced_bfd_database_path,
                    seq_limit=reduced_bfd_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=reduced_bfd_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits (uniref30 preferred over uniclust30 when both configured)
        if _all_exists(hhblits_binary_path, bfd_database_path, uniref30_database_path, hhblits_mode=True):
            self.bfd_uniref30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniref30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )
        elif _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

        # Nhmmer (RNA databases)
        if _all_exists(nhmmer_binary_path, rfam_database_path):
            self.rfam_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rfam_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rfam_max_hits
            )
        if _all_exists(nhmmer_binary_path, rnacentral_database_path):
            self.rnacentral_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rnacentral_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rnacentral_max_hits
            )
        if _all_exists(nhmmer_binary_path, nt_database_path):
            self.nt_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=nt_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=nt_max_hits
            )

        if _all_exists(hmmbuild_binary_path, hmmalign_binary_path):
            self.rna_realign_runner = _run_rna_realign_tool

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Runs all configured searches for the query FASTA.

        Args:
            input_fasta_path: Path to the (single-query) input FASTA file.
            output_msas_dir: Directory receiving the raw tool outputs
                (``*_hits.sto`` / ``*.a3m`` files and a ``templates`` subdir).
            use_precompute: If True, a search is skipped when its raw output
                file (or the final feature pickle, two directory levels above
                ``output_msas_dir``) already exists.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        reduced_bfd_out_path = os.path.join(output_msas_dir, "reduced_bfd_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniref30_out_path = os.path.join(output_msas_dir, "bfd_uniref30_hits.a3m")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, "bfd_uniclust30_hits.a3m")

        # Precomputed feature pickles live two levels above output_msas_dir,
        # keyed by md5("protein:<sequence>"); their presence short-circuits
        # the corresponding searches.
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        output_feature = os.path.dirname(output_msas_dir)
        output_feature = os.path.dirname(output_feature)
        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")
        pkl_save_path_temp = os.path.join(output_feature, "template_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_temp):
            if not os.path.exists(uniref90_out_path) or not use_precompute or not os.path.exists(pkl_save_path_temp):
                if not os.path.exists(uniref90_out_path):
                    print(uniref90_out_path)
                self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

            print("begin templates")
            if templates_out_path is not None \
                    and self.template_searcher is not None and self.template_featurizer is not None:
                try:
                    os.makedirs(templates_out_path, exist_ok=True)
                    seq, dec = parsers.parse_fasta(load_txt(input_fasta_path))
                    input_sequence = seq[0]
                    # Prepare the uniref90 MSA for the template search.
                    msa_for_templates = parsers.truncate_stockholm_msa(
                        uniref90_out_path, max_sequences=10000
                    )
                    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                        msa_for_templates
                    )
                    if self.template_searcher.input_format == "sto":
                        pdb_templates_result = self.template_searcher.query(msa_for_templates)
                    elif self.template_searcher.input_format == "a3m":
                        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                        pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                    else:
                        raise ValueError(
                            "Unrecognized template input format: "
                            f"{self.template_searcher.input_format}"
                        )

                    pdb_hits_out_path = os.path.join(
                        templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                    )
                    with open(os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                    ), "w") as f:
                        f.write(pdb_templates_result)

                    pdb_template_hits = self.template_searcher.get_template_hits(
                        output_string=pdb_templates_result, input_sequence=input_sequence
                    )
                    templates_result = self.template_featurizer.get_templates(
                        query_sequence=input_sequence, hits=pdb_template_hits
                    )
                    # BUGFIX: the dump must happen inside the try block.
                    # Previously it ran after the except handler, so a failed
                    # search logged the exception and then crashed with
                    # NameError/UnboundLocalError on templates_result /
                    # pdb_hits_out_path.
                    dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                except Exception:
                    logging.exception("An error in template searching")

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.reduced_bfd_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(reduced_bfd_out_path) or not use_precompute:
                self.reduced_bfd_jackhmmer_runner(input_fasta_path, reduced_bfd_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniref30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniref30_out_path) or not use_precompute:
                self.bfd_uniref30_hhblits_runner(input_fasta_path, bfd_uniref30_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
        # RNA searches are currently disabled:
        # if self.rfam_nhmmer_runner is not None:
        #     if not os.path.exists(rfam_out_path) or not use_precompute:
        #         self.rfam_nhmmer_runner(input_fasta_path, rfam_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(rfam_out_path):
        #         self.rna_realign_runner(input_fasta_path, rfam_out_path, rfam_out_realigned_path)
        # if self.rnacentral_nhmmer_runner is not None:
        #     if not os.path.exists(rnacentral_out_path) or not use_precompute:
        #         self.rnacentral_nhmmer_runner(input_fasta_path, rnacentral_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(rnacentral_out_path):
        #         self.rna_realign_runner(input_fasta_path, rnacentral_out_path, rnacentral_out_realigned_path)
        # if self.nt_nhmmer_runner is not None:
        #     if not os.path.exists(nt_out_path) or not use_precompute:
        #         self.nt_nhmmer_runner(input_fasta_path, nt_out_path)
        #     if self.rna_realign_runner is not None and os.path.exists(nt_out_path):
        #         self.rna_realign_runner(input_fasta_path, nt_out_path, nt_out_realigned_path)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class DataProcessor:
|
| 353 |
+
def __init__(
|
| 354 |
+
self,
|
| 355 |
+
alphafold3_database_path,
|
| 356 |
+
jackhmmer_binary_path: Optional[str] = None,
|
| 357 |
+
hhblits_binary_path: Optional[str] = None,
|
| 358 |
+
nhmmer_binary_path: Optional[str] = None,
|
| 359 |
+
kalign_binary_path: Optional[str] = None,
|
| 360 |
+
hmmbuild_binary_path: Optional[str] = None,
|
| 361 |
+
hmmalign_binary_path: Optional[str] = None,
|
| 362 |
+
hhsearch_binary_path: Optional[str] = None,
|
| 363 |
+
template_searcher: Optional[TemplateSearcher] = None,
|
| 364 |
+
template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,
|
| 365 |
+
n_cpus: int = 8,
|
| 366 |
+
n_workers: int = 1,
|
| 367 |
+
):
|
| 368 |
+
'''
|
| 369 |
+
Database Versions:
|
| 370 |
+
Training:
|
| 371 |
+
uniref90: v2022_05
|
| 372 |
+
bfd:
|
| 373 |
+
reduces_bfd:
|
| 374 |
+
uniclust30: v2018_08
|
| 375 |
+
uniprot: v2020_05
|
| 376 |
+
mgnify: v2022_05
|
| 377 |
+
rfam: v14.9
|
| 378 |
+
rnacentral: v21.0
|
| 379 |
+
nt: v2023_02_23
|
| 380 |
+
Inference:
|
| 381 |
+
uniref90: v2022_05
|
| 382 |
+
bfd:
|
| 383 |
+
reduces_bfd:
|
| 384 |
+
uniclust30: v2018_08
|
| 385 |
+
uniprot: v2021_04 *
|
| 386 |
+
mgnify: v2022_05
|
| 387 |
+
rfam: v14.9
|
| 388 |
+
rnacentral: v21.0
|
| 389 |
+
nt: v2023_02_23
|
| 390 |
+
Inference Ligand:
|
| 391 |
+
uniref90: v2020_01 *
|
| 392 |
+
bfd:
|
| 393 |
+
reduces_bfd:
|
| 394 |
+
uniclust30: v2018_08
|
| 395 |
+
uniprot: v2020_05
|
| 396 |
+
mgnify: v2018_12 *
|
| 397 |
+
rfam: v14.9
|
| 398 |
+
rnacentral: v21.0
|
| 399 |
+
nt: v2023_02_23
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
alphafold3_database_path: Database dir that contains all alphafold3 databases
|
| 403 |
+
jackhmmer_binary_path:
|
| 404 |
+
hhblits_binary_path:
|
| 405 |
+
nhmmer_binary_path:
|
| 406 |
+
kalign_binary_path:
|
| 407 |
+
hmmaligh_binary_path:
|
| 408 |
+
n_cpus:
|
| 409 |
+
n_workers:
|
| 410 |
+
'''
|
| 411 |
+
self.jackhmmer_binary_path = jackhmmer_binary_path
|
| 412 |
+
self.hhblits_binary_path = hhblits_binary_path
|
| 413 |
+
self.nhmmer_binary_path = nhmmer_binary_path
|
| 414 |
+
self.hmmbuild_binary_path = hmmbuild_binary_path
|
| 415 |
+
self.hmmalign_binary_path = hmmalign_binary_path
|
| 416 |
+
self.hhsearch_binary_path = hhsearch_binary_path
|
| 417 |
+
|
| 418 |
+
self.template_searcher = template_searcher
|
| 419 |
+
self.template_featurizer = template_featurizer
|
| 420 |
+
|
| 421 |
+
self.n_cpus = n_cpus
|
| 422 |
+
self.n_workers = n_workers
|
| 423 |
+
|
| 424 |
+
self.uniref90_database_path = os.path.join(
|
| 425 |
+
alphafold3_database_path, "uniref90", "uniref90.fasta"
|
| 426 |
+
)
|
| 427 |
+
################### TODO: DEBUG
|
| 428 |
+
self.uniprot_database_path = os.path.join(
|
| 429 |
+
alphafold3_database_path, "uniprot", "uniprot.fasta"
|
| 430 |
+
)
|
| 431 |
+
self.bfd_database_path = os.path.join(
|
| 432 |
+
alphafold3_database_path, "bfd", "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
|
| 433 |
+
)
|
| 434 |
+
self.uniclust30_database_path = os.path.join(
|
| 435 |
+
alphafold3_database_path, "uniclust30", "uniclust30_2018_08", "uniclust30_2018_08"
|
| 436 |
+
)
|
| 437 |
+
################### TODO: check alphafold2 multimer uniref30 version
|
| 438 |
+
self.uniref_30_database_path = os.path.join(
|
| 439 |
+
alphafold3_database_path, "uniref30", "v2020_06"
|
| 440 |
+
)
|
| 441 |
+
# self.reduced_bfd_database_path = os.path.join(
|
| 442 |
+
# alphafold3_database_path,"reduced_bfd"
|
| 443 |
+
# )
|
| 444 |
+
|
| 445 |
+
self.mgnify_database_path = os.path.join(
|
| 446 |
+
alphafold3_database_path, "mgnify", "mgnify", "mgy_clusters.fa"
|
| 447 |
+
)
|
| 448 |
+
self.rfam_database_path = os.path.join(
|
| 449 |
+
alphafold3_database_path, "rfam", "v14.9", "Rfam_af3_clustered_rep_seq.fasta"
|
| 450 |
+
)
|
| 451 |
+
self.rnacentral_database_path = os.path.join(
|
| 452 |
+
alphafold3_database_path, "rnacentral", "v21.0", "rnacentral_db_rep_seq.fasta"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
self.nt_database_path = os.path.join(
|
| 456 |
+
# alphafold3_database_path, "nt", "v2023_02_23", "nt_af3_clustered_rep_seq.fasta" # DEBUG
|
| 457 |
+
alphafold3_database_path, "nt", "v2023_02_23", "nt.fasta"
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
self.runner_args_map = {
|
| 461 |
+
"uniref90": {
|
| 462 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 463 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 464 |
+
},
|
| 465 |
+
"bfd_uniclust30": {
|
| 466 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 467 |
+
"bfd_database_path": self.bfd_database_path,
|
| 468 |
+
"uniclust30_database_path": self.uniclust30_database_path
|
| 469 |
+
},
|
| 470 |
+
"bfd_uniref30": {
|
| 471 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 472 |
+
"bfd_database_path": self.bfd_database_path,
|
| 473 |
+
"uniref_30_database_path": self.uniref_30_database_path
|
| 474 |
+
},
|
| 475 |
+
|
| 476 |
+
"mgnify": {
|
| 477 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 478 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 479 |
+
},
|
| 480 |
+
"uniprot": {
|
| 481 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 482 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 483 |
+
},
|
| 484 |
+
###################### RNA ########################
|
| 485 |
+
"rfam": {
|
| 486 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 487 |
+
"rfam_database_path": self.rfam_database_path,
|
| 488 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 489 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 490 |
+
},
|
| 491 |
+
"rnacentral": {
|
| 492 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 493 |
+
"rnacentral_database_path": self.rnacentral_database_path,
|
| 494 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 495 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 496 |
+
},
|
| 497 |
+
"nt": {
|
| 498 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 499 |
+
"nt_database_path": self.nt_database_path,
|
| 500 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 501 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 502 |
+
},
|
| 503 |
+
|
| 504 |
+
###################################################
|
| 505 |
+
"alphafold2": {
|
| 506 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 507 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 508 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 509 |
+
"bfd_database_path": self.bfd_database_path,
|
| 510 |
+
"uniclust30_database_path": self.uniclust30_database_path,
|
| 511 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 512 |
+
},
|
| 513 |
+
"alphafold2_multimer": {
|
| 514 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 515 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 516 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 517 |
+
"bfd_database_path": self.bfd_database_path,
|
| 518 |
+
"uniref_30_database_path": self.uniref_30_database_path,
|
| 519 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 520 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 521 |
+
},
|
| 522 |
+
"alphafold3": {
|
| 523 |
+
"jackhmmer_binary_path": self.jackhmmer_binary_path,
|
| 524 |
+
"hhblits_binary_path": self.hhblits_binary_path,
|
| 525 |
+
"template_searcher": self.template_searcher,
|
| 526 |
+
"template_featurizer": self.template_featurizer,
|
| 527 |
+
"uniref90_database_path": self.uniref90_database_path,
|
| 528 |
+
"bfd_database_path": self.bfd_database_path,
|
| 529 |
+
"uniclust30_database_path": self.uniclust30_database_path,
|
| 530 |
+
"mgnify_database_path": self.mgnify_database_path,
|
| 531 |
+
"uniprot_database_path": self.uniprot_database_path,
|
| 532 |
+
},
|
| 533 |
+
|
| 534 |
+
"rna": {
|
| 535 |
+
"nhmmer_binary_path": self.nhmmer_binary_path,
|
| 536 |
+
"rfam_database_path": self.rfam_database_path,
|
| 537 |
+
"rnacentral_database_path": self.rnacentral_database_path,
|
| 538 |
+
"hmmbuild_binary_path": self.hmmbuild_binary_path,
|
| 539 |
+
"hmmalign_binary_path": self.hmmalign_binary_path,
|
| 540 |
+
},
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
|
| 544 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 545 |
+
if isinstance(input_fasta_path, list):
|
| 546 |
+
input_fasta_paths = input_fasta_path
|
| 547 |
+
elif os.path.isdir(input_fasta_path):
|
| 548 |
+
input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
|
| 549 |
+
elif os.path.isfile(input_fasta_path):
|
| 550 |
+
input_fasta_paths = [input_fasta_path]
|
| 551 |
+
else:
|
| 552 |
+
input_fasta_paths = []
|
| 553 |
+
Exception("Can't parse input fasta path!")
|
| 554 |
+
seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
|
| 555 |
+
# sequences = [parsers.parse_fasta(load_txt(path))[0][0] for path in input_fasta_paths]
|
| 556 |
+
# TODO: debug
|
| 557 |
+
if convert_md5:
|
| 558 |
+
output_msas_dirs = [os.path.join(output_dir, convert_md5_string(f"{prefix}:{i}")) for i in
|
| 559 |
+
seqs]
|
| 560 |
+
else:
|
| 561 |
+
output_msas_dirs = [os.path.join(output_dir, os.path.split(i)[1].split(".")[0]) for i in input_fasta_paths]
|
| 562 |
+
io_tuples = [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]
|
| 563 |
+
return io_tuples
|
| 564 |
+
|
| 565 |
+
def _process_iotuple(self, io_tuple, msas_type):
|
| 566 |
+
i, o = io_tuple
|
| 567 |
+
alignment_runner = AlignmentRunner(
|
| 568 |
+
**self.runner_args_map[msas_type],
|
| 569 |
+
no_cpus=self.n_cpus
|
| 570 |
+
)
|
| 571 |
+
try:
|
| 572 |
+
alignment_runner.run(i, o)
|
| 573 |
+
except:
|
| 574 |
+
logging.warning(f"{i}:{o} task failed!")
|
| 575 |
+
|
| 576 |
+
def process(self, input_fasta_path, output_dir, msas_type="rfam", convert_md5=True):
|
| 577 |
+
prefix = "rna" if msas_type in ["rfam", "rnacentral", "nt", "rna"] else "protein"
|
| 578 |
+
io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=convert_md5, prefix=prefix)
|
| 579 |
+
run_pool_tasks(partial(self._process_iotuple, msas_type=msas_type), io_tuples, num_workers=self.n_workers,
|
| 580 |
+
return_dict=False)
|
| 581 |
+
|
| 582 |
+
def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
|
| 583 |
+
io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
|
| 584 |
+
io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)
|
| 585 |
+
|
| 586 |
+
for io0, io1 in tqdm.tqdm(zip(io_tuples, io_tuples_md5)):
|
| 587 |
+
o, o_md5 = io0[1], io1[1]
|
| 588 |
+
os.system(f"cp -r {os.path.abspath(o)} {os.path.abspath(o_md5)}")
|
PhysDock/data/tools/convert_unifold_template_to_stfold.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
sys.path.append("../")
|
| 8 |
+
|
| 9 |
+
from PhysDock.utils.io_utils import load_pkl, dump_pkl
|
| 10 |
+
from PhysDock.data.tools.residue_constants import \
|
| 11 |
+
hhblits_id_to_standard_residue_id_np, af3_if_to_residue_id
|
| 12 |
+
|
| 13 |
+
def dgram_from_positions(
|
| 14 |
+
pos: torch.Tensor,
|
| 15 |
+
min_bin: float = 3.25,
|
| 16 |
+
max_bin: float = 50.75,
|
| 17 |
+
no_bins: float = 39,
|
| 18 |
+
inf: float = 1e8,
|
| 19 |
+
):
|
| 20 |
+
dgram = torch.sum(
|
| 21 |
+
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
|
| 22 |
+
)
|
| 23 |
+
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
|
| 24 |
+
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
|
| 25 |
+
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
|
| 26 |
+
return dgram
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def convert_unifold_template_feature_to_stfold_unifold_feature(unifold_template_feature):
|
| 30 |
+
try:
|
| 31 |
+
print(unifold_template_feature)
|
| 32 |
+
md5_string = os.path.basename(unifold_template_feature)[:-6]
|
| 33 |
+
out_path = os.path.dirname(unifold_template_feature)
|
| 34 |
+
out_path = os.path.dirname(out_path)
|
| 35 |
+
unifold_template_feature = os.path.join(out_path, "msas", md5_string)
|
| 36 |
+
out_path_final = os.path.join(out_path, "msas", "template_features")
|
| 37 |
+
final = os.path.join(out_path_final, f"{md5_string}.pkl.gz")
|
| 38 |
+
if os.path.exists(final):
|
| 39 |
+
return dict()
|
| 40 |
+
|
| 41 |
+
unifold_template_feature = os.path.join(unifold_template_feature, "templates", "pdb_hits.hhr.pkl.gz")
|
| 42 |
+
if isinstance(unifold_template_feature, str):
|
| 43 |
+
data = load_pkl(unifold_template_feature)
|
| 44 |
+
else:
|
| 45 |
+
data = unifold_template_feature
|
| 46 |
+
template_restype = af3_if_to_residue_id[
|
| 47 |
+
hhblits_id_to_standard_residue_id_np[np.argmax(data["template_aatype"], axis=-1)]]
|
| 48 |
+
assert np.all(template_restype != -1)
|
| 49 |
+
assert len(template_restype) >= 1
|
| 50 |
+
assert len(template_restype[0, :]) >= 4
|
| 51 |
+
# shape = template_restype.shape
|
| 52 |
+
# template_restype=template_restype.view([-1])[].view(shape)
|
| 53 |
+
|
| 54 |
+
bb_x_gt = torch.from_numpy(data["template_all_atom_positions"][..., :3, :])
|
| 55 |
+
bb_x_mask = torch.from_numpy(data["template_all_atom_masks"][..., :3])
|
| 56 |
+
|
| 57 |
+
bb_x_gt_beta1 = data["template_all_atom_positions"][..., 3, :]
|
| 58 |
+
bb_x_gt_beta_mask1 = data["template_all_atom_masks"][..., 3]
|
| 59 |
+
bb_x_gt_beta2 = data["template_all_atom_positions"][..., 1, :]
|
| 60 |
+
bb_x_gt_beta_mask2 = data["template_all_atom_masks"][..., 1]
|
| 61 |
+
|
| 62 |
+
is_gly = template_restype == 7
|
| 63 |
+
template_pseudo_beta = np.where(is_gly[..., None], bb_x_gt_beta2, bb_x_gt_beta1)
|
| 64 |
+
template_pseudo_beta_mask = np.where(is_gly, bb_x_gt_beta_mask2, bb_x_gt_beta_mask1)
|
| 65 |
+
template_backbone_frame_mask = bb_x_mask[..., 0] * bb_x_mask[..., 1] * bb_x_mask[..., 2]
|
| 66 |
+
out = {
|
| 67 |
+
"template_restype": template_restype.astype(np.int8),
|
| 68 |
+
"template_backbone_frame_mask": template_backbone_frame_mask.numpy().astype(np.int8),
|
| 69 |
+
"template_backbone_frame": bb_x_gt.numpy().astype(np.float32),
|
| 70 |
+
"template_pseudo_beta": template_pseudo_beta.astype(np.float32),
|
| 71 |
+
"template_pseudo_beta_mask": template_pseudo_beta_mask.astype(np.int8),
|
| 72 |
+
# "template_backbone_mask": template_mask.numpy().astype(np.int8),
|
| 73 |
+
}
|
| 74 |
+
# for k,v in out.items():
|
| 75 |
+
# print(k,v.shape)
|
| 76 |
+
dump_pkl(out, os.path.join(out_path_final, f"{md5_string}.pkl.gz"), compress=True)
|
| 77 |
+
except:
|
| 78 |
+
pass
|
| 79 |
+
out = dict()
|
| 80 |
+
print(f"dump templ feats to {md5_string}.pkl.gz")
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# HHBLITS_ID_TO_AA = {
|
| 85 |
+
# 0: "ALA",
|
| 86 |
+
# 1: "CYS", # Also U.
|
| 87 |
+
# 2: "ASP", # Also B.
|
| 88 |
+
# 3: "GLU", # Also Z.
|
| 89 |
+
# 4: "PHE",
|
| 90 |
+
# 5: "GLY",
|
| 91 |
+
# 6: "HIS",
|
| 92 |
+
# 7: "ILE",
|
| 93 |
+
# 8: "LYS",
|
| 94 |
+
# 9: "LEU",
|
| 95 |
+
# 10: "MET",
|
| 96 |
+
# 11: "ASN",
|
| 97 |
+
# 12: "PRO",
|
| 98 |
+
# 13: "GLN",
|
| 99 |
+
# 14: "ARG",
|
| 100 |
+
# 15: "SER",
|
| 101 |
+
# 16: "THR",
|
| 102 |
+
# 17: "VAL",
|
| 103 |
+
# 18: "TRP",
|
| 104 |
+
# 19: "TYR",
|
| 105 |
+
# 20: "UNK", # Includes J and O.
|
| 106 |
+
# 21: "GAP",
|
| 107 |
+
# }
|
| 108 |
+
#
|
| 109 |
+
# # Usage: Convert hhblits msa to af3 aatype
|
| 110 |
+
# # msa = hhblits_id_to_standard_residue_id_np[hhblits_msa.astype(np.int64)]
|
| 111 |
+
# hhblits_id_to_standard_residue_id_np = np.array(
|
| 112 |
+
# [standard_ccds.index(ccd) for id, ccd in HHBLITS_ID_TO_AA.items()]
|
| 113 |
+
# )
|
| 114 |
+
#
|
| 115 |
+
# of_restypes = [
|
| 116 |
+
# "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
|
| 117 |
+
# "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X", "-"
|
| 118 |
+
# ]
|
| 119 |
+
#
|
| 120 |
+
# af3_restypes = [amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "-" if ccd == "GAP" else "None" for ccd in
|
| 121 |
+
# standard_ccds
|
| 122 |
+
# ]
|
| 123 |
+
#
|
| 124 |
+
# af3_if_to_residue_id = np.array(
|
| 125 |
+
# [af3_restypes.index(restype) if restype in of_restypes else -1 for restype in af3_restypes])
|
| 126 |
+
|
| 127 |
+
|
PhysDock/data/tools/dataset_manager.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from functools import partial
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from PhysDock.utils.io_utils import run_pool_tasks, load_json, load_txt, dump_pkl, \
|
| 7 |
+
convert_md5_string
|
| 8 |
+
from PhysDock.data.tools.parsers import parse_fasta
|
| 9 |
+
from PhysDock.data.tools.parse_msas import parse_protein_alignment_dir, parse_uniprot_alignment_dir, \
|
| 10 |
+
parse_rna_alignment_dir
|
| 11 |
+
from PhysDock.data.alignment_runner import DataProcessor
|
| 12 |
+
from PhysDock.data.tools.residue_constants import standard_ccds, amino_acid_3to1
|
| 13 |
+
from PhysDock.data.tools.PDBData import protein_letters_3to1_extended, nucleic_letters_3to1_extended
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_protein_md5(sequence_3):
|
| 17 |
+
ccds = sequence_3.split("-")
|
| 18 |
+
|
| 19 |
+
start = 0
|
| 20 |
+
end = 0
|
| 21 |
+
for i, ccd in enumerate(ccds):
|
| 22 |
+
if ccd not in ["UNK", "N ", "DN ", "GAP"]:
|
| 23 |
+
start = i
|
| 24 |
+
break
|
| 25 |
+
for i, ccd in enumerate(ccds[::-1]):
|
| 26 |
+
if ccd not in ["UNK", "N ", "DN ", "GAP"]:
|
| 27 |
+
end = i
|
| 28 |
+
break
|
| 29 |
+
# print(start,end)
|
| 30 |
+
ccds_strip_unk = ccds[start:-end] if end > 0 else ccds[start:]
|
| 31 |
+
sequence_0 = "".join(
|
| 32 |
+
[protein_letters_3to1_extended[ccd] if ccd in protein_letters_3to1_extended else "X" for ccd in ccds])
|
| 33 |
+
sequence_1 = "".join([amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "X" for ccd in ccds])
|
| 34 |
+
sequence_2 = "".join(
|
| 35 |
+
[protein_letters_3to1_extended[ccd] if ccd in protein_letters_3to1_extended else "X" for ccd in ccds_strip_unk])
|
| 36 |
+
sequence_3 = "".join([amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "X" for ccd in ccds_strip_unk])
|
| 37 |
+
|
| 38 |
+
sequences = []
|
| 39 |
+
for sequence in [sequence_0, sequence_1, sequence_2, sequence_3]:
|
| 40 |
+
if sequence not in sequences:
|
| 41 |
+
sequences.append(sequence)
|
| 42 |
+
return sequences, [convert_md5_string(f"protein:{i}") for i in sequences]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_rna_md5(sequence_3):
|
| 46 |
+
ccds = sequence_3.split("-")
|
| 47 |
+
chs = [nucleic_letters_3to1_extended[ccd] if ccd not in ["UNK", "GAP", "N ", "DN "] else "N" for ccd in ccds]
|
| 48 |
+
sequence = "".join(chs)
|
| 49 |
+
md5 = convert_md5_string(f"rna:{sequence}")
|
| 50 |
+
return sequence, md5
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DatasetManager:
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
dataset_path
|
| 57 |
+
):
|
| 58 |
+
self.dataset_path = dataset_path
|
| 59 |
+
# Meta Data
|
| 60 |
+
self.chain_id_to_meta_info = load_json(os.path.join(dataset_path, "chain_id_to_meta_info.json"))
|
| 61 |
+
self.pdb_id_to_meta_info = load_json(os.path.join(dataset_path, "pdb_id_to_meta_info.json"))
|
| 62 |
+
self.ccd_id_to_meta_info = load_json(os.path.join(dataset_path, "ccd_id_to_meta_info.json"))
|
| 63 |
+
|
| 64 |
+
# filtering chains
|
| 65 |
+
# self.train_polymer_chain_ids = load_json()
|
| 66 |
+
# self.validation_polymer_chain_ids = load_json()
|
| 67 |
+
|
| 68 |
+
# def check_protein_msa_features_completeness(input_fasta_path):
|
| 69 |
+
# pass
|
| 70 |
+
#
|
| 71 |
+
# def check_protein_uniprot_msa_features_completeness(self, chain_ids, num_workers):
|
| 72 |
+
# def _run(chain_id):
|
| 73 |
+
# if chain_id not in self.chain_id_to_meta_info:
|
| 74 |
+
# return {
|
| 75 |
+
# "chain_id": f"Not find this chain {chain_id}"
|
| 76 |
+
# }
|
| 77 |
+
# sequence_3 = self.chain_id_to_meta_info[chain_id]["sequence_3"]
|
| 78 |
+
#
|
| 79 |
+
# out = run_pool_tasks(_run, chain_ids, return_dict=True, num_workers=num_workers)
|
| 80 |
+
# return out
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def homo_search(
|
| 84 |
+
input_fasta_path,
|
| 85 |
+
output_dir,
|
| 86 |
+
msas_type,
|
| 87 |
+
convert_md5,
|
| 88 |
+
alphafold3_database_path,
|
| 89 |
+
jackhmmer_binary_path: Optional[str] = None,
|
| 90 |
+
hhblits_binary_path: Optional[str] = None,
|
| 91 |
+
nhmmer_binary_path: Optional[str] = None,
|
| 92 |
+
kalign_binary_path: Optional[str] = None,
|
| 93 |
+
hmmbuild_binary_path: Optional[str] = None,
|
| 94 |
+
hmmalign_binary_path: Optional[str] = None,
|
| 95 |
+
n_cpus: int = 8,
|
| 96 |
+
n_workers: int = 1,
|
| 97 |
+
):
|
| 98 |
+
data_processor = DataProcessor(
|
| 99 |
+
alphafold3_database_path=alphafold3_database_path,
|
| 100 |
+
jackhmmer_binary_path=jackhmmer_binary_path,
|
| 101 |
+
hhblits_binary_path=hhblits_binary_path,
|
| 102 |
+
nhmmer_binary_path=nhmmer_binary_path,
|
| 103 |
+
kalign_binary_path=kalign_binary_path,
|
| 104 |
+
hmmbuild_binary_path=hmmbuild_binary_path,
|
| 105 |
+
hmmalign_binary_path=hmmalign_binary_path,
|
| 106 |
+
n_cpus=n_cpus,
|
| 107 |
+
n_workers=n_workers,
|
| 108 |
+
)
|
| 109 |
+
data_processor.process(
|
| 110 |
+
input_fasta_path=input_fasta_path,
|
| 111 |
+
output_dir=output_dir,
|
| 112 |
+
msas_type=msas_type,
|
| 113 |
+
# msas_type="bfd_uniclust30", # alphafold3 # rna
|
| 114 |
+
# msas_type="uniprot", # alphafold3 # rna
|
| 115 |
+
convert_md5=convert_md5
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
def get_unsearched_input_fasta_path(input_fasta_path, output_dir, msas_type, convert_md5, num_workers=128):
|
| 120 |
+
if isinstance(input_fasta_path, list):
|
| 121 |
+
input_fasta_paths = input_fasta_path
|
| 122 |
+
elif os.path.isdir(input_fasta_path):
|
| 123 |
+
input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
|
| 124 |
+
elif os.path.isfile(input_fasta_path):
|
| 125 |
+
input_fasta_paths = [input_fasta_path]
|
| 126 |
+
else:
|
| 127 |
+
input_fasta_paths = []
|
| 128 |
+
Exception("Can't parse input fasta path!")
|
| 129 |
+
|
| 130 |
+
prefix = {
|
| 131 |
+
"uniref90": "protein",
|
| 132 |
+
"bfd_uniclust30": "protein",
|
| 133 |
+
"bfd_uniref30": "protein",
|
| 134 |
+
"uniprot": "protein",
|
| 135 |
+
"mgnify": "protein",
|
| 136 |
+
"rfam": "rna",
|
| 137 |
+
"rnacentral": "rna",
|
| 138 |
+
"nt": "rna",
|
| 139 |
+
}[msas_type]
|
| 140 |
+
global _get_unsearched_input_fasta_path
|
| 141 |
+
|
| 142 |
+
def _get_unsearched_input_fasta_path(input_fasta_path, convert_md5, prefix, output_dir, msas_type):
|
| 143 |
+
# TODO
|
| 144 |
+
# seqs, decs = parse_fasta(i)
|
| 145 |
+
if convert_md5:
|
| 146 |
+
# dec = convert_md5_string(f"{prefix}:{input_fasta_path}")
|
| 147 |
+
dec = convert_md5_string(f"{prefix}:{seqs[0]}")
|
| 148 |
+
else:
|
| 149 |
+
dec = os.path.split(input_fasta_path)[1].split(".")[0]
|
| 150 |
+
|
| 151 |
+
if os.path.exists(os.path.join(output_dir, dec, f"{msas_type}_hits.sto")) or \
|
| 152 |
+
os.path.exists(os.path.join(output_dir, dec, f"{msas_type}_hits.a3m")):
|
| 153 |
+
return dict()
|
| 154 |
+
else:
|
| 155 |
+
return {input_fasta_path: False}
|
| 156 |
+
|
| 157 |
+
out = run_pool_tasks(partial(
|
| 158 |
+
_get_unsearched_input_fasta_path,
|
| 159 |
+
convert_md5=convert_md5,
|
| 160 |
+
prefix=prefix,
|
| 161 |
+
output_dir=output_dir,
|
| 162 |
+
msas_type=msas_type,
|
| 163 |
+
), input_fasta_paths, num_workers=num_workers, return_dict=True)
|
| 164 |
+
return list(out.keys())
|
| 165 |
+
|
| 166 |
+
@staticmethod
|
| 167 |
+
def convert_msas_out_to_msa_features(
|
| 168 |
+
input_fasta_path,
|
| 169 |
+
output_dir,
|
| 170 |
+
msa_feature_dir,
|
| 171 |
+
convert_md5=True,
|
| 172 |
+
num_workers=128
|
| 173 |
+
):
|
| 174 |
+
if isinstance(input_fasta_path, list):
|
| 175 |
+
input_fasta_paths = input_fasta_path
|
| 176 |
+
elif os.path.isdir(input_fasta_path):
|
| 177 |
+
input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
|
| 178 |
+
elif os.path.isfile(input_fasta_path):
|
| 179 |
+
input_fasta_paths = [input_fasta_path]
|
| 180 |
+
else:
|
| 181 |
+
input_fasta_paths = []
|
| 182 |
+
Exception("Can't parse input fasta path!")
|
| 183 |
+
|
| 184 |
+
global _convert_msas_out_to_msa_features
|
| 185 |
+
|
| 186 |
+
def _convert_msas_out_to_msa_features(
|
| 187 |
+
input_fasta_path,
|
| 188 |
+
output_dir,
|
| 189 |
+
msa_feature_dir,
|
| 190 |
+
convert_md5=True,
|
| 191 |
+
):
|
| 192 |
+
prefix = "protein"
|
| 193 |
+
max_seq = 16384
|
| 194 |
+
seqs, decs = parse_fasta(load_txt(input_fasta_path))
|
| 195 |
+
md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
|
| 196 |
+
if convert_md5:
|
| 197 |
+
# TODO: debug
|
| 198 |
+
# dec = convert_md5_string(f"{prefix}:{input_fasta_path}")
|
| 199 |
+
dec = md5
|
| 200 |
+
else:
|
| 201 |
+
dec = os.path.split(input_fasta_path)[1].split(".")[0]
|
| 202 |
+
|
| 203 |
+
# DEBUG: whc homo search hits
|
| 204 |
+
if os.path.exists(os.path.join(output_dir, dec, "msas")):
|
| 205 |
+
dec = dec + "/msas"
|
| 206 |
+
|
| 207 |
+
pkl_save_path = os.path.join(msa_feature_dir, f"{md5}.pkl.gz")
|
| 208 |
+
|
| 209 |
+
if os.path.exists(pkl_save_path):
|
| 210 |
+
return dict()
|
| 211 |
+
if os.path.exists(os.path.join(output_dir, dec, "uniref90_hits.sto")) and \
|
| 212 |
+
os.path.exists(os.path.join(output_dir, dec, "bfd_uniclust30_hits.a3m")) and \
|
| 213 |
+
os.path.exists(os.path.join(output_dir, dec, "mgnify_hits.sto")):
|
| 214 |
+
msa_feature = parse_protein_alignment_dir(os.path.join(output_dir, dec))
|
| 215 |
+
|
| 216 |
+
sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
|
| 217 |
+
md5_string = convert_md5_string(f"protein:{sequence}")
|
| 218 |
+
if md5 == md5_string:
|
| 219 |
+
feature = {
|
| 220 |
+
"msa": msa_feature["msa"][:max_seq].astype(np.int8),
|
| 221 |
+
"deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
|
| 222 |
+
"msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
|
| 223 |
+
}
|
| 224 |
+
dump_pkl(feature, pkl_save_path)
|
| 225 |
+
return dict()
|
| 226 |
+
else:
|
| 227 |
+
return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
|
| 228 |
+
|
| 229 |
+
# DEBUG: whc
|
| 230 |
+
elif os.path.exists(os.path.join(output_dir, dec, "uniref90_hits.sto")) and \
|
| 231 |
+
os.path.exists(os.path.join(output_dir, dec, "bfd_uniref_hits.a3m")) and \
|
| 232 |
+
os.path.exists(os.path.join(output_dir, dec, "mgnify_hits.sto")):
|
| 233 |
+
msa_feature = parse_protein_alignment_dir(os.path.join(output_dir, dec))
|
| 234 |
+
|
| 235 |
+
sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
|
| 236 |
+
md5_string = convert_md5_string(f"protein:{sequence}")
|
| 237 |
+
if md5 == md5_string:
|
| 238 |
+
feature = {
|
| 239 |
+
"msa": msa_feature["msa"][:max_seq].astype(np.int8),
|
| 240 |
+
"deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
|
| 241 |
+
"msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
|
| 242 |
+
}
|
| 243 |
+
dump_pkl(feature, pkl_save_path)
|
| 244 |
+
return dict()
|
| 245 |
+
else:
|
| 246 |
+
return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
|
| 247 |
+
|
| 248 |
+
elif os.path.exists(os.path.join(output_dir, dec, "uniref90_hits.sto")) and \
|
| 249 |
+
os.path.exists(os.path.join(output_dir, dec, "bfd_uniclust30_hits.a3m")):
|
| 250 |
+
msa_feature = parse_protein_alignment_dir(os.path.join(output_dir, dec))
|
| 251 |
+
if len(msa_feature["msa"]) < max_seq:
|
| 252 |
+
return {
|
| 253 |
+
input_fasta_path: f"MSA is not enough!"
|
| 254 |
+
}
|
| 255 |
+
sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
|
| 256 |
+
md5_string = convert_md5_string(f"protein:{sequence}")
|
| 257 |
+
if md5 == md5_string:
|
| 258 |
+
feature = {
|
| 259 |
+
"msa": msa_feature["msa"][:max_seq].astype(np.int8),
|
| 260 |
+
"deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
|
| 261 |
+
"msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
|
| 262 |
+
}
|
| 263 |
+
dump_pkl(feature, pkl_save_path)
|
| 264 |
+
return dict()
|
| 265 |
+
else:
|
| 266 |
+
return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
|
| 267 |
+
elif os.path.exists(os.path.join(output_dir, dec, "uniref90_hits.sto")) and \
|
| 268 |
+
os.path.exists(os.path.join(output_dir, dec, "mgnify_hits.sto")):
|
| 269 |
+
msa_feature = parse_protein_alignment_dir(os.path.join(output_dir, dec))
|
| 270 |
+
|
| 271 |
+
sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
|
| 272 |
+
md5_string = convert_md5_string(f"protein:{sequence}")
|
| 273 |
+
if md5 == md5_string:
|
| 274 |
+
feature = {
|
| 275 |
+
"msa": msa_feature["msa"][:max_seq].astype(np.int8),
|
| 276 |
+
"deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
|
| 277 |
+
"msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
|
| 278 |
+
}
|
| 279 |
+
dump_pkl(feature, pkl_save_path)
|
| 280 |
+
return dict()
|
| 281 |
+
else:
|
| 282 |
+
return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
|
| 283 |
+
|
| 284 |
+
else:
|
| 285 |
+
# msa_feature = parse_protein_alignment_dir(os.path.join(output_dir, dec))
|
| 286 |
+
#
|
| 287 |
+
# sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
|
| 288 |
+
# md5_string = convert_md5_string(f"protein:{sequence}")
|
| 289 |
+
# if md5 == md5_string:
|
| 290 |
+
# feature = {
|
| 291 |
+
# "msa": msa_feature["msa"][:max_seq].astype(np.int8),
|
| 292 |
+
# "deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
|
| 293 |
+
# "msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
|
| 294 |
+
# }
|
| 295 |
+
# dump_pkl(feature, pkl_save_path)
|
| 296 |
+
# return dict()
|
| 297 |
+
# else:
|
| 298 |
+
# return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
|
| 299 |
+
|
| 300 |
+
return {
|
| 301 |
+
input_fasta_path: f"MSA is not enough!"
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
out = run_pool_tasks(partial(
|
| 305 |
+
_convert_msas_out_to_msa_features,
|
| 306 |
+
output_dir=output_dir,
|
| 307 |
+
msa_feature_dir=msa_feature_dir,
|
| 308 |
+
convert_md5=convert_md5
|
| 309 |
+
), input_fasta_paths, num_workers=num_workers, return_dict=True)
|
| 310 |
+
return out
|
| 311 |
+
|
| 312 |
+
@staticmethod
def convert_msas_out_to_uniprot_msa_features(
        input_fasta_path,
        output_dir,
        uniprot_msa_feature_dir,
        convert_md5=True,
        num_workers=128
):
    """Convert raw uniprot alignment output directories into cached
    per-sequence feature pickles named ``<md5>.pkl.gz``.

    Args:
        input_fasta_path: a fasta file path, a directory of fasta files,
            or a list of fasta file paths.
        output_dir: directory holding one alignment sub-directory per query
            (each expected to contain ``uniprot_hits.sto``).
        uniprot_msa_feature_dir: destination directory for the pickles.
        convert_md5: if True the alignment sub-directory is named by the md5
            of ``"protein:<sequence>"``; otherwise by the fasta file stem.
        num_workers: worker-pool size for ``run_pool_tasks``.

    Returns:
        dict mapping each failed fasta path to an error message
        (empty dict on full success).

    Raises:
        ValueError: if ``input_fasta_path`` is neither a list, a directory,
            nor a file.
    """
    if isinstance(input_fasta_path, list):
        input_fasta_paths = input_fasta_path
    elif os.path.isdir(input_fasta_path):
        input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
    elif os.path.isfile(input_fasta_path):
        input_fasta_paths = [input_fasta_path]
    else:
        # BUGFIX: the exception was previously constructed but never raised,
        # silently continuing with an empty work list.
        raise ValueError("Can't parse input fasta path!")

    # Rebound as a module-level global so the nested worker can be pickled
    # by the multiprocessing pool used in run_pool_tasks.
    global _convert_msas_out_to_uniprot_msa_features

    def _convert_msas_out_to_uniprot_msa_features(
            input_fasta_path,
            output_dir,
            uniprot_msa_feature_dir,
            convert_md5=True,
    ):
        prefix = "protein"
        max_seq = 50000  # cap on the number of alignment rows kept
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        if convert_md5:
            dec = md5
        else:
            dec = os.path.split(input_fasta_path)[1].split(".")[0]

        pkl_save_path = os.path.join(uniprot_msa_feature_dir, f"{md5}.pkl.gz")

        if os.path.exists(pkl_save_path):
            return dict()  # already converted; skip
        if os.path.exists(os.path.join(output_dir, dec, "uniprot_hits.sto")):
            msa_feature = parse_uniprot_alignment_dir(os.path.join(output_dir, dec))

            # Re-derive the query sequence from the first MSA row and check
            # it hashes back to the expected md5 before caching anything.
            sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa_all_seq"][0]])
            md5_string = convert_md5_string(f"protein:{sequence}")
            if md5 == md5_string:
                feature = {
                    "msa_all_seq": msa_feature["msa_all_seq"][:max_seq].astype(np.int8),
                    "deletion_matrix_all_seq": msa_feature["deletion_matrix_all_seq"][:max_seq].astype(np.int8),
                    "msa_species_identifiers_all_seq": msa_feature["msa_species_identifiers_all_seq"][:max_seq]
                }
                dump_pkl(feature, pkl_save_path)
                return dict()
            else:
                return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
        else:
            return {
                input_fasta_path: f"MSA is not enough!"
            }

    out = run_pool_tasks(partial(
        _convert_msas_out_to_uniprot_msa_features,
        output_dir=output_dir,
        uniprot_msa_feature_dir=uniprot_msa_feature_dir,
        convert_md5=convert_md5
    ), input_fasta_paths, num_workers=num_workers, return_dict=True)
    return out
|
| 381 |
+
|
| 382 |
+
@staticmethod
def convert_msas_out_to_rna_msa_features(
        input_fasta_path,
        output_dir,
        rna_msa_feature_dir,
        convert_md5=True,
        num_workers=128
):
    """Convert raw RNA alignment output directories into cached
    per-sequence feature pickles named ``<md5>.pkl.gz``.

    Args:
        input_fasta_path: a fasta file path, a directory of fasta files,
            or a list of fasta file paths.
        output_dir: directory holding one alignment sub-directory per query.
        rna_msa_feature_dir: destination directory (created if missing).
        convert_md5: if True the alignment sub-directory is named by the md5
            of ``"rna:<sequence>"``; otherwise by the fasta file stem.
        num_workers: worker-pool size for ``run_pool_tasks``.

    Returns:
        dict mapping each failed fasta path to an error message
        (empty dict on full success).

    Raises:
        ValueError: if ``input_fasta_path`` is neither a list, a directory,
            nor a file.
    """
    # NOTE: redundant function-local `import os` removed; `os` is used
    # unqualified by the sibling methods, so a module-level import exists.
    os.makedirs(rna_msa_feature_dir, exist_ok=True)
    if isinstance(input_fasta_path, list):
        input_fasta_paths = input_fasta_path
    elif os.path.isdir(input_fasta_path):
        input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
    elif os.path.isfile(input_fasta_path):
        input_fasta_paths = [input_fasta_path]
    else:
        # BUGFIX: the exception was previously constructed but never raised,
        # silently continuing with an empty work list.
        raise ValueError("Can't parse input fasta path!")

    # Rebound as a module-level global so the nested worker can be pickled
    # by the multiprocessing pool used in run_pool_tasks.
    global _convert_msas_out_to_rna_msa_features

    def _convert_msas_out_to_rna_msa_features(
            input_fasta_path,
            output_dir,
            rna_msa_feature_dir,
            convert_md5=True,
    ):
        prefix = "rna"
        max_seq = 16384  # cap on the number of alignment rows kept
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        if convert_md5:
            dec = md5
        else:
            dec = os.path.split(input_fasta_path)[1].split(".")[0]

        # Some homology-search runs nest their results under an extra
        # "msas" sub-directory; follow it when present.
        if os.path.exists(os.path.join(output_dir, dec, "msas")):
            dec = dec + "/msas"

        pkl_save_path = os.path.join(rna_msa_feature_dir, f"{md5}.pkl.gz")

        if os.path.exists(pkl_save_path):
            return dict()  # already converted; skip
        rna_msa_feature = parse_rna_alignment_dir(
            os.path.join(output_dir, dec),
            input_fasta_path
        )

        feature = {
            "msa": rna_msa_feature["msa"][:max_seq].astype(np.int8),
            "deletion_matrix": rna_msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
            # No species pairing metadata is produced for RNA alignments.
            "msa_species_identifiers": None
        }
        dump_pkl(feature, pkl_save_path)

        return dict()

    out = run_pool_tasks(partial(
        _convert_msas_out_to_rna_msa_features,
        output_dir=output_dir,
        rna_msa_feature_dir=rna_msa_feature_dir,
        convert_md5=convert_md5
    ), input_fasta_paths, num_workers=num_workers, return_dict=True)
    return out
|
| 450 |
+
|
| 451 |
+
@staticmethod
def find_chain_ids_without_msa_features(
        polymer_filtering_out_json,
        chain_id_to_meta_info_path,
        dataset_dir,
        uniprot=False,
        num_workers=256,
):
    """Report protein chains whose MSA feature pickle is missing on disk.

    Args:
        polymer_filtering_out_json: one JSON path, or a list of JSON paths,
            mapping chain_id -> filtering result (entries with
            ``v["state"]`` truthy are considered).
        chain_id_to_meta_info_path: JSON with per-chain metadata; must
            provide ``sequence_3`` and ``chain_class`` per chain_id.
        dataset_dir: dataset root; feature pickles are looked up under
            ``<dataset_dir>/features/<variant>/<md5>.pkl.gz``.
        uniprot: if True, search the uniprot MSA feature variants instead
            of the regular ones.
        num_workers: worker-pool size for ``run_pool_tasks``.

    Returns:
        dict mapping each chain_id that has NO feature file in any variant
        directory to ``{"state": False, "seqs": seqs}``; chains with at
        least one existing pickle contribute nothing.
    """
    if not isinstance(polymer_filtering_out_json, list):
        polymer_filtering_out_json = [polymer_filtering_out_json]
    polymer_filtering_out = dict()
    # Merge all filtering-result JSONs; later files override earlier keys.
    for i in polymer_filtering_out_json:
        polymer_filtering_out.update(load_json(i))
    chain_id_to_meta_info = load_json(chain_id_to_meta_info_path)
    # Rebound as a module-level global so the nested worker can be pickled
    # by the multiprocessing pool (NOTE: this shadows the static method name
    # at module scope).
    global find_chain_ids_without_msa_features

    def find_chain_ids_without_msa_features(chain_id):
        sequence_3 = chain_id_to_meta_info[chain_id]["sequence_3"]
        seqs, md5s = get_protein_md5(sequence_3)
        # Feature pickles may live in any of several historical directories;
        # the chain is "missing" only if none of them has it.
        if uniprot:
            dirs = ["uniprot_msa_features", "uniprot_msa_features_zkx", "uniprot_msa_features_unifold"]
        else:
            dirs = ["msa_features", "msa_features_zkx", "msa_features_whc", "msa_features_unifold"]
        md5_dir = [[md5, dir] for dir in dirs for md5 in md5s]
        for md5, dir in md5_dir:
            if os.path.exists(os.path.join(dataset_dir, "features/", dir, f"{md5}.pkl.gz")):
                return dict()
        return {chain_id: {"state": False, "seqs": seqs}}

    # Only protein chains that passed polymer filtering are checked.
    chain_ids = [k for k, v in polymer_filtering_out.items() if
                 v["state"] and chain_id_to_meta_info[k]["chain_class"] == "protein"]

    out = run_pool_tasks(
        find_chain_ids_without_msa_features, chain_ids, num_workers=num_workers, return_dict=True)
    return out
|
| 486 |
+
|
| 487 |
+
def find_chain_ids_without_rna_msa_features(
        self,
        polymer_filtering_out_json,
):
    """Placeholder — RNA counterpart of ``find_chain_ids_without_msa_features``; not implemented."""
    pass
|
| 492 |
+
|
| 493 |
+
@staticmethod
def check_msa_md5(msa_feature_dir):
    """Placeholder — not implemented (presumably meant to validate cached MSA feature files against their md5 filenames; TODO confirm)."""
    pass
|
| 496 |
+
|
| 497 |
+
@staticmethod
def check_uniprot_msa_md5(uniprot_msa_feature_dir):
    """Placeholder — not implemented (uniprot variant of ``check_msa_md5``)."""
    pass
|
| 500 |
+
|
| 501 |
+
@staticmethod
def check_rna_msa_md5(rna_msa_feature_dir):
    """Placeholder — not implemented (RNA variant of ``check_msa_md5``)."""
    pass
|
| 504 |
+
|
| 505 |
+
def get_training_pdbs(self):
    """Placeholder — not implemented."""
    pass
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class DataPipeline():
    """Thin entry point for the data pipeline.

    Currently only instantiates a :class:`DatasetManager`; everything below
    the constructor is design notes on example-sampling weights, kept as
    comments — none of it is implemented behavior.
    """

    def __init__(self):
        super().__init__()
        # Dataset bookkeeping is delegated to the manager.
        self.data_manager = DatasetManager()

    # Design notes on training-example sampling weights:
    #
    # PDB:
    #   polymer_chain_id:
    #     weight_chain: contiguous_crop: 1/3, spatial_crop: 2/3
    #   Interface:
    #     weight_interface:
    #       [chain_id1, chain_id2] 0.2 contiguous_crop
    #       [chain_id1, chain_id2] 0.4 spatial_crop_interface
    #       [ < 20 chains]         0.4 spatial crop
    #
    # polymer chain contiguous crop    sample weight w_chain * 1/3         [chain_id]
    # polymer chain spatial crop       sample weight w_chain * 2/3         [chain_id]
    #
    # interface contiguous crop        sample weight w_interface * 0.2     [chain_id, chain_id]
    # interface spatial crop           sample weight w_interface * 0.4     >[chain_id, chain_id]
    # interface spatial crop interface sample weight w_interface * 0.4     [chain_id, chain_id]
    #
    # pdb:
    #   chain:
    #     [chain_id]: 0.14
    #     [chain_id]: 0.23

    # Commented out in the original (relies on module-level globals such as
    # pdb_id_to_meta_info / chain_id_to_meta_info / stfold_data_path):
    #
    # @staticmethod
    # def get_pdb_info(pdb_id):
    #     all_chain_ids = pdb_id_to_meta_info[pdb_id]["chain_ids"]
    #
    #     chain_ids_info = {
    #         "protein": [],
    #         "rna": [],
    #         "dna": [],
    #         "ligand": []
    #     }
    #     for chain_id_ in all_chain_ids:
    #         chain_id = f"{pdb_id}_{chain_id_}"
    #         if chain_id in chain_id_to_meta_info:
    #             chain_class = chain_id_to_meta_info[chain_id]["chain_class"].split("_")[0]
    #             if chain_id in chain_ids and os.path.exists(os.path.join(stfold_data_path, f"{chain_id}.pkl.gz")):
    #                 if chain_class == "protein":
    #                     if check_protein_msa_features(chain_id, chain_id_to_meta_info)[chain_id]["state"]:
    #                         chain_ids_info[chain_class].append(chain_id)
    #                 elif chain_class == "rna":
    #                     if check_rna_msa_features(chain_id, chain_id_to_meta_info)[chain_id]["state"]:
    #                         chain_ids_info[chain_class].append(chain_id)
    #
    #             elif chain_class == "ligand" and os.path.exists(os.path.join(stfold_data_path, f"{chain_id}.pkl.gz")):
    #                 chain_ids_info[chain_class].append(chain_id)
    #     return {pdb_id: chain_ids_info}
|
PhysDock/data/tools/feature_processing_multimer.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
# Copyright 2022 AlQuraishi Laboratory
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Feature processing logic for multimer data pipeline."""
|
| 17 |
+
from typing import Iterable, MutableMapping, List, Mapping, Dict, Any, Union
|
| 18 |
+
from scipy.sparse import coo_matrix
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from . import msa_pairing
|
| 22 |
+
|
| 23 |
+
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
|
| 24 |
+
# TODO: Move this into the config
|
| 25 |
+
REQUIRED_FEATURES = frozenset({
|
| 26 |
+
'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
|
| 27 |
+
'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
|
| 28 |
+
'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
|
| 29 |
+
'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
|
| 30 |
+
'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
|
| 31 |
+
'num_templates', 'queue_size', 'residue_index', 'resolution',
|
| 32 |
+
'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
|
| 33 |
+
'template_all_atom_mask', 'template_all_atom_positions'
|
| 34 |
+
})
|
| 35 |
+
|
| 36 |
+
MAX_TEMPLATES = 4
|
| 37 |
+
MSA_CROP_SIZE = 16384
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
|
| 41 |
+
"""Checks if a list of chains represents a homomer/monomer example."""
|
| 42 |
+
# Note that an entity_id of 0 indicates padding.
|
| 43 |
+
# num_unique_chains = len(np.unique(np.concatenate(
|
| 44 |
+
# [np.unique(chain['entity_id'][chain['entity_id'] > 0]) for
|
| 45 |
+
# chain in chains])))
|
| 46 |
+
|
| 47 |
+
# return num_unique_chains == 1
|
| 48 |
+
num_chains = len(chains)
|
| 49 |
+
return num_chains == 1
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def pair_and_merge(
        all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
        is_homomer_or_monomer,
) -> FeatureDict:
    """Runs processing on features to augment, pair and merge.

    Args:
        all_chain_features: A MutableMap of dictionaries of features for each chain.
        is_homomer_or_monomer: precomputed flag; when False, cross-chain MSA
            pairing is performed on the protein chains.

    Returns:
        A dictionary of merged features for the whole assembly.
    """

    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())
    # Split chains by molecule class; proteins are the only class that
    # participates in MSA pairing.
    np_chains_list_prot = [chain for chain in np_chains_list if
                           chain['chain_class'] in ['protein']]
    np_chains_list_dna = [chain for chain in np_chains_list if
                          chain['chain_class'] in ['dna']]
    np_chains_list_rna = [chain for chain in np_chains_list if
                          chain['chain_class'] in ['rna', ]]
    np_chains_list_ligand = [chain for chain in np_chains_list if chain['chain_class'] in ['ligand']]

    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
    pair_msa_sequences_prot = not is_homomer_or_monomer
    if pair_msa_sequences_prot and np_chains_list_prot:
        # Pair uniprot "_all_seq" MSAs across the protein chains.
        np_chains_list_prot = msa_pairing.create_paired_features(
            chains=np_chains_list_prot
        )
    else:
        if np_chains_list_prot:
            for prot in np_chains_list_prot:
                prot["num_alignments"] = np.ones([], dtype=np.int32)

    # Species identifiers are only needed for pairing; drop them before merge.
    for chain in np_chains_list_prot:
        chain.pop("msa_species_identifiers", None)
        chain.pop("msa_species_identifiers_all_seq", None)

    # Re-assemble the full chain list: proteins first, then rna/dna/ligand.
    np_chains_list_prot.extend(np_chains_list_rna)
    np_chains_list_prot.extend(np_chains_list_dna)
    np_chains_list_prot.extend(np_chains_list_ligand)

    np_chains_list = np_chains_list_prot

    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences_prot,
        max_templates=MAX_TEMPLATES
    )

    # NOTE(review): merge uses `pair_msa_sequences` (assembly-level flag)
    # while cropping used the protein-level flag — confirm this asymmetry
    # is intentional.
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )

    return np_example
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def crop_chains(
        chains_list: List[Mapping[str, np.ndarray]],
        msa_crop_size: int,
        pair_msa_sequences: bool,
        max_templates: int
) -> List[Mapping[str, np.ndarray]]:
    """Crops (or pads, for non-protein chains) the MSAs for a set of chains.

    Protein chains are delegated to ``_crop_single_chain``. Every other
    chain class is brought to exactly ``msa_crop_size`` MSA rows: truncated
    when larger, padded by resampling existing rows with replacement when
    smaller. Chains are modified in place.

    Args:
        chains_list: A list of chains to be cropped.
        msa_crop_size: The total number of sequences to keep per MSA.
        pair_msa_sequences: Whether we are operating in sequence-pairing mode.
        max_templates: The maximum templates to use per chain.

    Returns:
        The (mutated) chains cropped.
    """
    cropped_chains = []
    for chain in chains_list:
        if chain['chain_class'] in ['protein']:
            cropped_chain = _crop_single_chain(
                chain,
                msa_crop_size=msa_crop_size,
                pair_msa_sequences=pair_msa_sequences,
                max_templates=max_templates)
        else:
            msa_size = chain['msa'].shape[0]
            # BUGFIX: use the msa_crop_size parameter instead of the module
            # constant MSA_CROP_SIZE, which silently ignored the argument
            # (behavior is unchanged for callers that pass MSA_CROP_SIZE).
            target_size = msa_crop_size
            if msa_size < target_size:
                # Pad by resampling existing rows with replacement, keeping
                # the deletion matrix aligned with the resampled MSA rows.
                sample_msa_id = np.random.choice(np.arange(msa_size), target_size - msa_size, replace=True)
                sample_msa = chain['msa'][sample_msa_id, :]
                chain['msa'] = np.concatenate([chain['msa'], sample_msa], axis=0)
                sample_msa_del = chain['deletion_matrix'][sample_msa_id, :]
                chain['deletion_matrix'] = np.concatenate([chain['deletion_matrix'], sample_msa_del], axis=0)
            else:
                chain['msa'] = chain['msa'][:target_size, :]
                chain['deletion_matrix'] = chain['deletion_matrix'][:target_size, :]
            cropped_chain = chain
        cropped_chains.append(cropped_chain)

    return cropped_chains
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _crop_single_chain(chain: Mapping[str, np.ndarray],
                       msa_crop_size: int,
                       pair_msa_sequences: bool,
                       max_templates: int) -> Mapping[str, np.ndarray]:
    """Crops (or pads) a protein chain's MSA features to `msa_crop_size`.

    In pairing mode up to half of the budget goes to the paired
    ``_all_seq`` MSA; the remainder is filled from the unpaired MSA,
    padding by resampling rows with replacement when it is too small.
    The chain dict is mutated in place.
    """
    msa_size = len(chain['msa'])

    if pair_msa_sequences:
        msa_size_all_seq = chain['msa_all_seq'].shape[0]
        # Paired MSA gets at most half of the total crop budget.
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
    else:
        msa_crop_size_all_seq = 0

    include_templates = 'template_aatype' in chain and max_templates
    # BUGFIX: always define templates_crop_size — previously a template
    # feature present without 'template_aatype' raised a NameError below.
    templates_crop_size = 0
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    # BUGFIX: budget the unpaired MSA from the msa_crop_size parameter
    # rather than the module constant MSA_CROP_SIZE, which silently ignored
    # the argument (unchanged for callers that pass MSA_CROP_SIZE).
    target_size = msa_crop_size - msa_crop_size_all_seq
    if msa_size < target_size:
        # Row indices used to pad every unpaired MSA feature consistently.
        sample_msa_id = np.random.choice(np.arange(msa_size), target_size - msa_size, replace=True)
    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                if msa_size < target_size:
                    sample_msa = chain[k][sample_msa_id, :]
                    chain[k] = np.concatenate([chain[k], sample_msa], axis=0)
                else:
                    chain[k] = chain[k][:target_size, :]

    chain['num_alignments'] = np.asarray(len(chain['msa']), dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            len(chain['msa_all_seq']), dtype=np.int32)

    return chain
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def print_final(
        np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Debug hook: returns the merged example unchanged."""
    return np_example
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _filter_features(
        np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested (REQUIRED_FEATURES)."""
    filtered = {}
    for key, value in np_example.items():
        if key in REQUIRED_FEATURES:
            filtered[key] = value
    return filtered
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def process_unmerged_features(
        all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
    """Postprocessing stage for per-chain features before merging.

    Mutates each chain's feature dict in place, adding:
      * 'deletion_mean'       — column-wise mean of the deletion matrix.
      * 'assembly_num_chains' — total number of chains in the assembly.
      * 'entity_mask'         — 1 where entity_id != 0 (0 marks padding).
    """
    chain_count = len(all_chain_features)
    for features in all_chain_features.values():
        features['deletion_mean'] = np.mean(features['deletion_matrix'], axis=0)
        features['assembly_num_chains'] = np.asarray(chain_count)
        features['entity_mask'] = (features['entity_id'] != 0).astype(np.int32)
|
PhysDock/data/tools/get_metrics.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
sys.path.append("../")
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import scipy
|
| 12 |
+
import random
|
| 13 |
+
import logging
|
| 14 |
+
import collections
|
| 15 |
+
from functools import partial
|
| 16 |
+
from typing import Union, Tuple, Dict
|
| 17 |
+
import itertools
|
| 18 |
+
from PhysDock.utils.tensor_utils import tensor_tree_map
|
| 19 |
+
from PhysDock.utils.io_utils import load_json, load_txt,dump_json,dump_txt,dump_pkl, load_pkl
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _calculate_bin_centers(breaks: np.ndarray):
|
| 24 |
+
"""Gets the bin centers from the bin edges.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
breaks: [num_bins - 1] the error bin edges.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
bin_centers: [num_bins] the error bin centers.
|
| 31 |
+
"""
|
| 32 |
+
step = (breaks[1] - breaks[0])
|
| 33 |
+
|
| 34 |
+
# Add half-step to get the center
|
| 35 |
+
bin_centers = breaks + step / 2
|
| 36 |
+
# Add a catch-all bin at the end.
|
| 37 |
+
bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
|
| 38 |
+
axis=0)
|
| 39 |
+
return bin_centers
|
| 40 |
+
|
| 41 |
+
def _calculate_expected_aligned_error(
        alignment_confidence_breaks: np.ndarray,
        aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Calculates expected aligned distance errors for every pair of residues.

    Args:
      alignment_confidence_breaks: [num_bins - 1] the error bin edges.
      aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
        probs for each error bin, for each pair of residues.

    Returns:
      predicted_aligned_error: [num_res, num_res] the expected aligned distance
        error for each pair of residues.
      max_predicted_aligned_error: The maximum predicted error possible.
    """
    centers = _calculate_bin_centers(alignment_confidence_breaks)
    # Expectation over bins, plus the center of the last (catch-all) bin.
    expected_error = np.sum(aligned_distance_error_probs * centers, axis=-1)
    max_error = np.asarray(centers[-1])
    return expected_error, max_error
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compute_plddt(logits: np.ndarray) -> np.ndarray:
    """Computes per-residue pLDDT from logits.

    Args:
      logits: [num_res, num_bins] output from the PredictedLDDTHead.

    Returns:
      plddt: [num_res] per-residue pLDDT on a 0-100 scale.
    """
    n_bins = logits.shape[-1]
    width = 1.0 / n_bins
    # Centers of n_bins equal-width bins spanning [0, 1].
    centers = np.arange(start=0.5 * width, stop=1.0, step=width)
    probabilities = scipy.special.softmax(logits, axis=-1)
    expected_lddt = np.sum(probabilities * centers[None, :], axis=-1)
    return expected_lddt * 100
|
| 78 |
+
|
| 79 |
+
def predicted_tm_score(
        logits: np.ndarray,
        breaks: np.ndarray,
        residue_weights: Optional[np.ndarray] = None,
        asym_id: Optional[np.ndarray] = None,
        interface: bool = False) -> np.ndarray:
    """Computes predicted TM alignment or predicted interface TM alignment score.

    Args:
      logits: [num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
      breaks: [num_bins] the error bins.
      residue_weights: [num_res] the per residue weights to use for the
        expectation.
      asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
        ipTM calculation, i.e. when interface=True.
      interface: If True, interface predicted TM score is computed.

    Returns:
      ptm_score: The predicted TM alignment or the predicted iTM score.
    """

    # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
    # exp. resolved head's probability.
    if residue_weights is None:
        residue_weights = np.ones(logits.shape[0])

    bin_centers = _calculate_bin_centers(breaks)

    num_res = int(np.sum(residue_weights))
    # Clip num_res to avoid negative/undefined d0.
    clipped_num_res = max(num_res, 19)

    # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
    # "Scoring function for automated assessment of protein structure template
    # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
    d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

    # Convert logits to probs.
    probs = scipy.special.softmax(logits, axis=-1)

    # TM-Score term for every bin.
    tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
    # E_distances tm(distance).
    predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)

    # For ipTM only cross-chain residue pairs contribute.
    pair_mask = np.ones_like(predicted_tm_term, dtype=bool)
    if interface:
        pair_mask *= asym_id[:, None] != asym_id[None, :]

    predicted_tm_term *= pair_mask

    pair_residue_weights = pair_mask * (
            residue_weights[None, :] * residue_weights[:, None])
    normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
        pair_residue_weights, axis=-1, keepdims=True))
    # Score of the best-aligned reference residue, per TM-score convention.
    per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
    return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def compute_predicted_aligned_error(
        logits: np.ndarray,
        breaks: np.ndarray) -> Dict[str, np.ndarray]:
    """Computes aligned confidence metrics from logits.

    Args:
      logits: [num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
      breaks: [num_bins - 1] the error bin edges.

    Returns:
      aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
        aligned error probabilities over bins for each residue pair.
      predicted_aligned_error: [num_res, num_res] the expected aligned distance
        error for each pair of residues.
      max_predicted_aligned_error: The maximum predicted error possible.
    """
    # Per-pair distribution over error bins.
    aligned_confidence_probs = scipy.special.softmax(
        logits,
        axis=-1)
    predicted_aligned_error, max_predicted_aligned_error = (
        _calculate_expected_aligned_error(
            alignment_confidence_breaks=breaks,
            aligned_distance_error_probs=aligned_confidence_probs))
    return {
        'aligned_confidence_probs': aligned_confidence_probs,
        'predicted_aligned_error': predicted_aligned_error,
        'max_predicted_aligned_error': max_predicted_aligned_error,
    }
|
| 168 |
+
|
| 169 |
+
def get_has_clash(atom_pos, atom_mask, asym_id, is_polymer_chain):
|
| 170 |
+
"""
|
| 171 |
+
A structure is marked as having a clash (has_clash) if for any two
|
| 172 |
+
polymer chains A,B in the prediction clashes(A,B) > 100 or
|
| 173 |
+
clashes(A,B) / min(NA,NB) > 0.5 where NA is the number of atoms in
|
| 174 |
+
chain A.
|
| 175 |
+
Args:
|
| 176 |
+
atom_pos: [N_atom, 3]
|
| 177 |
+
atom_mask: [N_atom]
|
| 178 |
+
asym_id: [N_atom]
|
| 179 |
+
is_polymer_chain: [N_atom]
|
| 180 |
+
"""
|
| 181 |
+
flag = np.logical_and(atom_mask == 1, is_polymer_chain == 1)
|
| 182 |
+
atom_pos = atom_pos[flag]
|
| 183 |
+
asym_id = asym_id[flag]
|
| 184 |
+
uniq_asym_ids = np.unique(asym_id)
|
| 185 |
+
n = len(uniq_asym_ids)
|
| 186 |
+
if n == 1:
|
| 187 |
+
return 0
|
| 188 |
+
for aid1 in uniq_asym_ids[:-1]:
|
| 189 |
+
for aid2 in uniq_asym_ids[1:]:
|
| 190 |
+
pos1 = atom_pos[asym_id == aid1]
|
| 191 |
+
pos2 = atom_pos[asym_id == aid2]
|
| 192 |
+
dist = np.sqrt(np.sum((pos1[None] - pos2[:, None]) ** 2, -1))
|
| 193 |
+
n_clash = np.sum(dist < 1.1).astype('float32')
|
| 194 |
+
if n_clash > 100 or n_clash / min(len(pos1), len(pos2)) > 0.5:
|
| 195 |
+
return 1
|
| 196 |
+
return 0
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_metrics(output, batch):
|
| 202 |
+
"""
|
| 203 |
+
Args:
|
| 204 |
+
logits_plddt: (B, N_atom, b_plddt)
|
| 205 |
+
logits_pae: (B, N_token, N_token, b_pae)
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
atom_plddts: (B, N_atom)
|
| 209 |
+
mean_plddt: (B,)
|
| 210 |
+
pae: (B, N_token, N_token)
|
| 211 |
+
ptm: (B,)
|
| 212 |
+
iptm: (B,)
|
| 213 |
+
has_clash: (B,)
|
| 214 |
+
ranking_confidence: (B,)
|
| 215 |
+
"""
|
| 216 |
+
logit_value = output
|
| 217 |
+
|
| 218 |
+
# B = logit_value['p_pae'].shape[0]
|
| 219 |
+
breaks_pae = torch.linspace(0.,
|
| 220 |
+
0.5 * 64,
|
| 221 |
+
64 - 1)
|
| 222 |
+
inputs = {
|
| 223 |
+
's_mask': batch['s_mask'],
|
| 224 |
+
'asym_id': batch['asym_id'],
|
| 225 |
+
'breaks_pae': torch.tile(breaks_pae, [ 1]),
|
| 226 |
+
# 'perm_asym_id': batch['perm_asym_id'],
|
| 227 |
+
'is_polymer_chain': ((batch['is_protein'] +
|
| 228 |
+
batch['is_dna'] + batch['is_rna']) > 0),
|
| 229 |
+
**logit_value,
|
| 230 |
+
**batch
|
| 231 |
+
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
ret_list = []
|
| 235 |
+
# for i in range(B):
|
| 236 |
+
cur_input = tensor_tree_map(lambda x: x.numpy(), inputs)
|
| 237 |
+
# cur_input = inputs
|
| 238 |
+
ret = get_all_atom_confidence_metrics(cur_input,0)
|
| 239 |
+
ret_list.append(ret)
|
| 240 |
+
|
| 241 |
+
metrics = {}
|
| 242 |
+
for k, v in ret_list[0].items():
|
| 243 |
+
metrics[k] = torch.from_numpy(np.stack([r[k] for r in ret_list]))
|
| 244 |
+
return metrics
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_all_atom_confidence_metrics(
|
| 249 |
+
prediction_result,b):
|
| 250 |
+
"""get_all_atom_confidence_metrics."""
|
| 251 |
+
metrics = {}
|
| 252 |
+
metrics['atom_plddts'] = compute_plddt(
|
| 253 |
+
prediction_result['p_plddt'])
|
| 254 |
+
metrics['mean_plddt'] = metrics['atom_plddts'].mean()
|
| 255 |
+
metrics['pae'] = compute_predicted_aligned_error(
|
| 256 |
+
logits=prediction_result['p_pae'],
|
| 257 |
+
breaks=prediction_result['breaks_pae'])['predicted_aligned_error']
|
| 258 |
+
metrics['ptm'] = predicted_tm_score(
|
| 259 |
+
logits=prediction_result['p_pae'],
|
| 260 |
+
breaks=prediction_result['breaks_pae'],
|
| 261 |
+
residue_weights=prediction_result['s_mask'],
|
| 262 |
+
asym_id=None)
|
| 263 |
+
metrics['iptm'] = predicted_tm_score(
|
| 264 |
+
logits=prediction_result['p_pae'],
|
| 265 |
+
breaks=prediction_result['breaks_pae'],
|
| 266 |
+
residue_weights=prediction_result['s_mask'],
|
| 267 |
+
asym_id=prediction_result['asym_id'],
|
| 268 |
+
interface=True)
|
| 269 |
+
metrics['has_clash'] = get_has_clash(
|
| 270 |
+
prediction_result['x_pred'][b],
|
| 271 |
+
prediction_result['a_mask'],
|
| 272 |
+
prediction_result['asym_id'][prediction_result["atom_id_to_token_id"]],
|
| 273 |
+
~prediction_result['is_ligand'][prediction_result["atom_id_to_token_id"]])
|
| 274 |
+
metrics['ranking_confidence'] = (
|
| 275 |
+
0.8 * metrics['iptm'] + 0.2 * metrics['ptm']
|
| 276 |
+
- 1.0 * metrics['has_clash'])
|
| 277 |
+
return metrics
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# output = load_pkl("../output.pkl.gz")
|
| 281 |
+
# feats = load_pkl("../feats.pkl.gz")
|
| 282 |
+
# for k,v in output.items():
|
| 283 |
+
# print(k,v.shape)
|
| 284 |
+
# # output[k] = torch.from_numpy(v)
|
| 285 |
+
# for k,v in feats.items():
|
| 286 |
+
# print(k,v.shape)
|
| 287 |
+
# # feats[k] = torch.from_numpy(v)
|
| 288 |
+
# # dump_pkl(output,"../output.pkl.gz")
|
| 289 |
+
# # dump_pkl(feats,"../feats.pkl.gz")
|
| 290 |
+
#
|
| 291 |
+
# metrics = get_metrics(output,feats)
|
| 292 |
+
#
|
| 293 |
+
# for k,v in metrics.items():
|
| 294 |
+
# print(k,v)
|
PhysDock/data/tools/hhblits.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run HHblits from Python."""
|
| 17 |
+
import glob
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
from typing import Any, List, Mapping, Optional, Sequence
|
| 22 |
+
|
| 23 |
+
from . import utils
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_HHBLITS_DEFAULT_P = 20
|
| 27 |
+
_HHBLITS_DEFAULT_Z = 500
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class HHBlits:
|
| 31 |
+
"""Python wrapper of the HHblits binary."""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
*,
|
| 36 |
+
binary_path: str,
|
| 37 |
+
databases: Sequence[str],
|
| 38 |
+
n_cpu: int = 4,
|
| 39 |
+
n_iter: int = 3,
|
| 40 |
+
e_value: float = 0.001,
|
| 41 |
+
maxseq: int = 1_000_000,
|
| 42 |
+
realign_max: int = 100_000,
|
| 43 |
+
maxfilt: int = 100_000,
|
| 44 |
+
min_prefilter_hits: int = 1000,
|
| 45 |
+
all_seqs: bool = False,
|
| 46 |
+
alt: Optional[int] = None,
|
| 47 |
+
p: int = _HHBLITS_DEFAULT_P,
|
| 48 |
+
z: int = _HHBLITS_DEFAULT_Z,
|
| 49 |
+
):
|
| 50 |
+
"""Initializes the Python HHblits wrapper.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
binary_path: The path to the HHblits executable.
|
| 54 |
+
databases: A sequence of HHblits database paths. This should be the
|
| 55 |
+
common prefix for the database files (i.e. up to but not including
|
| 56 |
+
_hhm.ffindex etc.)
|
| 57 |
+
n_cpu: The number of CPUs to give HHblits.
|
| 58 |
+
n_iter: The number of HHblits iterations.
|
| 59 |
+
e_value: The E-value, see HHblits docs for more details.
|
| 60 |
+
maxseq: The maximum number of rows in an input alignment. Note that this
|
| 61 |
+
parameter is only supported in HHBlits version 3.1 and higher.
|
| 62 |
+
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
|
| 63 |
+
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
|
| 64 |
+
HHblits default: 20000.
|
| 65 |
+
min_prefilter_hits: Min number of hits to pass prefilter.
|
| 66 |
+
HHblits default: 100.
|
| 67 |
+
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
|
| 68 |
+
HHblits default: False.
|
| 69 |
+
alt: Show up to this many alternative alignments.
|
| 70 |
+
p: Minimum Prob for a hit to be included in the output hhr file.
|
| 71 |
+
HHblits default: 20.
|
| 72 |
+
z: Hard cap on number of hits reported in the hhr file.
|
| 73 |
+
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
|
| 74 |
+
|
| 75 |
+
Raises:
|
| 76 |
+
RuntimeError: If HHblits binary not found within the path.
|
| 77 |
+
"""
|
| 78 |
+
self.binary_path = binary_path
|
| 79 |
+
self.databases = databases
|
| 80 |
+
|
| 81 |
+
for database_path in self.databases:
|
| 82 |
+
if not glob.glob(database_path + "_*"):
|
| 83 |
+
logging.error(
|
| 84 |
+
"Could not find HHBlits database %s", database_path
|
| 85 |
+
)
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"Could not find HHBlits database {database_path}"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
self.n_cpu = n_cpu
|
| 91 |
+
self.n_iter = n_iter
|
| 92 |
+
self.e_value = e_value
|
| 93 |
+
self.maxseq = maxseq
|
| 94 |
+
self.realign_max = realign_max
|
| 95 |
+
self.maxfilt = maxfilt
|
| 96 |
+
self.min_prefilter_hits = min_prefilter_hits
|
| 97 |
+
self.all_seqs = all_seqs
|
| 98 |
+
self.alt = alt
|
| 99 |
+
self.p = p
|
| 100 |
+
self.z = z
|
| 101 |
+
|
| 102 |
+
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
|
| 103 |
+
"""Queries the database using HHblits."""
|
| 104 |
+
with utils.tmpdir_manager() as query_tmp_dir:
|
| 105 |
+
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
|
| 106 |
+
|
| 107 |
+
db_cmd = []
|
| 108 |
+
for db_path in self.databases:
|
| 109 |
+
db_cmd.append("-d")
|
| 110 |
+
db_cmd.append(db_path)
|
| 111 |
+
cmd = [
|
| 112 |
+
self.binary_path,
|
| 113 |
+
"-i",
|
| 114 |
+
input_fasta_path,
|
| 115 |
+
"-cpu",
|
| 116 |
+
str(self.n_cpu),
|
| 117 |
+
"-oa3m",
|
| 118 |
+
a3m_path,
|
| 119 |
+
"-o",
|
| 120 |
+
"/dev/null",
|
| 121 |
+
"-n",
|
| 122 |
+
str(self.n_iter),
|
| 123 |
+
"-e",
|
| 124 |
+
str(self.e_value),
|
| 125 |
+
"-maxseq",
|
| 126 |
+
str(self.maxseq),
|
| 127 |
+
"-realign_max",
|
| 128 |
+
str(self.realign_max),
|
| 129 |
+
"-maxfilt",
|
| 130 |
+
str(self.maxfilt),
|
| 131 |
+
"-min_prefilter_hits",
|
| 132 |
+
str(self.min_prefilter_hits),
|
| 133 |
+
]
|
| 134 |
+
if self.all_seqs:
|
| 135 |
+
cmd += ["-all"]
|
| 136 |
+
if self.alt:
|
| 137 |
+
cmd += ["-alt", str(self.alt)]
|
| 138 |
+
if self.p != _HHBLITS_DEFAULT_P:
|
| 139 |
+
cmd += ["-p", str(self.p)]
|
| 140 |
+
if self.z != _HHBLITS_DEFAULT_Z:
|
| 141 |
+
cmd += ["-Z", str(self.z)]
|
| 142 |
+
cmd += db_cmd
|
| 143 |
+
|
| 144 |
+
logging.info('Launching subprocess "%s"', " ".join(cmd))
|
| 145 |
+
process = subprocess.Popen(
|
| 146 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
with utils.timing("HHblits query"):
|
| 150 |
+
stdout, stderr = process.communicate()
|
| 151 |
+
retcode = process.wait()
|
| 152 |
+
|
| 153 |
+
if retcode:
|
| 154 |
+
# Logs have a 15k character limit, so log HHblits error line by line.
|
| 155 |
+
logging.error("HHblits failed. HHblits stderr begin:")
|
| 156 |
+
for error_line in stderr.decode("utf-8").splitlines():
|
| 157 |
+
if error_line.strip():
|
| 158 |
+
logging.error(error_line.strip())
|
| 159 |
+
logging.error("HHblits stderr end")
|
| 160 |
+
raise RuntimeError(
|
| 161 |
+
"HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
|
| 162 |
+
% (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
with open(a3m_path) as f:
|
| 166 |
+
a3m = f.read()
|
| 167 |
+
|
| 168 |
+
raw_output = dict(
|
| 169 |
+
a3m=a3m,
|
| 170 |
+
output=stdout,
|
| 171 |
+
stderr=stderr,
|
| 172 |
+
n_iter=self.n_iter,
|
| 173 |
+
e_value=self.e_value,
|
| 174 |
+
)
|
| 175 |
+
return [raw_output]
|
PhysDock/data/tools/hhsearch.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run HHsearch from Python."""
|
| 17 |
+
import glob
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
from typing import Sequence, Optional
|
| 22 |
+
|
| 23 |
+
from . import parsers
|
| 24 |
+
from . import utils
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HHSearch:
|
| 28 |
+
"""Python wrapper of the HHsearch binary."""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
*,
|
| 33 |
+
binary_path: str,
|
| 34 |
+
databases: Sequence[str],
|
| 35 |
+
n_cpu: int = 2,
|
| 36 |
+
maxseq: int = 1_000_000,
|
| 37 |
+
):
|
| 38 |
+
"""Initializes the Python HHsearch wrapper.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
binary_path: The path to the HHsearch executable.
|
| 42 |
+
databases: A sequence of HHsearch database paths. This should be the
|
| 43 |
+
common prefix for the database files (i.e. up to but not including
|
| 44 |
+
_hhm.ffindex etc.)
|
| 45 |
+
n_cpu: The number of CPUs to use
|
| 46 |
+
maxseq: The maximum number of rows in an input alignment. Note that this
|
| 47 |
+
parameter is only supported in HHBlits version 3.1 and higher.
|
| 48 |
+
|
| 49 |
+
Raises:
|
| 50 |
+
RuntimeError: If HHsearch binary not found within the path.
|
| 51 |
+
"""
|
| 52 |
+
self.binary_path = binary_path
|
| 53 |
+
self.databases = databases
|
| 54 |
+
self.n_cpu = n_cpu
|
| 55 |
+
self.maxseq = maxseq
|
| 56 |
+
|
| 57 |
+
for database_path in self.databases:
|
| 58 |
+
if not glob.glob(database_path + "_*"):
|
| 59 |
+
logging.error(
|
| 60 |
+
"Could not find HHsearch database %s", database_path
|
| 61 |
+
)
|
| 62 |
+
raise ValueError(
|
| 63 |
+
f"Could not find HHsearch database {database_path}"
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def output_format(self) -> str:
|
| 68 |
+
return 'hhr'
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def input_format(self) -> str:
|
| 72 |
+
# return 'sto'
|
| 73 |
+
return 'a3m'
|
| 74 |
+
|
| 75 |
+
def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
|
| 76 |
+
"""Queries the database using HHsearch using a given a3m."""
|
| 77 |
+
with utils.tmpdir_manager() as query_tmp_dir:
|
| 78 |
+
input_path = os.path.join(query_tmp_dir, "query.a3m")
|
| 79 |
+
output_dir = query_tmp_dir if output_dir is None else output_dir
|
| 80 |
+
hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
|
| 81 |
+
with open(input_path, "w") as f:
|
| 82 |
+
f.write(a3m)
|
| 83 |
+
|
| 84 |
+
db_cmd = []
|
| 85 |
+
for db_path in self.databases:
|
| 86 |
+
db_cmd.append("-d")
|
| 87 |
+
db_cmd.append(db_path)
|
| 88 |
+
cmd = [
|
| 89 |
+
self.binary_path,
|
| 90 |
+
"-i",
|
| 91 |
+
input_path,
|
| 92 |
+
"-o",
|
| 93 |
+
hhr_path,
|
| 94 |
+
"-maxseq",
|
| 95 |
+
str(self.maxseq),
|
| 96 |
+
"-cpu",
|
| 97 |
+
str(self.n_cpu),
|
| 98 |
+
] + db_cmd
|
| 99 |
+
|
| 100 |
+
logging.info('Launching subprocess "%s"', " ".join(cmd))
|
| 101 |
+
process = subprocess.Popen(
|
| 102 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
| 103 |
+
)
|
| 104 |
+
with utils.timing("HHsearch query"):
|
| 105 |
+
stdout, stderr = process.communicate()
|
| 106 |
+
retcode = process.wait()
|
| 107 |
+
|
| 108 |
+
# if retcode:
|
| 109 |
+
# # Stderr is truncated to prevent proto size errors in Beam.
|
| 110 |
+
# raise RuntimeError(
|
| 111 |
+
# "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
|
| 112 |
+
# % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
|
| 113 |
+
# )
|
| 114 |
+
|
| 115 |
+
with open(hhr_path) as f:
|
| 116 |
+
hhr = f.read()
|
| 117 |
+
return hhr
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def get_template_hits(
|
| 121 |
+
output_string: str,
|
| 122 |
+
input_sequence: str
|
| 123 |
+
) -> Sequence[parsers.TemplateHit]:
|
| 124 |
+
"""Gets parsed template hits from the raw string output by the tool"""
|
| 125 |
+
del input_sequence # Used by hmmsearch but not needed for hhsearch
|
| 126 |
+
return parsers.parse_hhr(output_string)
|
PhysDock/data/tools/hmmalign.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
from typing import Optional, Sequence
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
from . import parsers
|
| 7 |
+
from . import hmmbuild
|
| 8 |
+
from . import utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Hmmalign(object):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
*,
|
| 15 |
+
hmmbuild_binary_path: str,
|
| 16 |
+
hmmalign_binary_path: str,
|
| 17 |
+
):
|
| 18 |
+
self.binary_path = hmmalign_binary_path
|
| 19 |
+
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def output_format(self) -> str:
|
| 23 |
+
return 'sto'
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def input_format(self) -> str:
|
| 27 |
+
return 'sto'
|
| 28 |
+
|
| 29 |
+
def realign_sto_with_fasta(self, input_fasta_path, input_sto_path, output_sto_path: Optional = None) -> str:
|
| 30 |
+
delete_out = False if output_sto_path is not None else True
|
| 31 |
+
with utils.tmpdir_manager() as query_tmp_dir:
|
| 32 |
+
hmm_output_path = os.path.join(query_tmp_dir, 'query.hmm')
|
| 33 |
+
output_sto_path = os.path.join(query_tmp_dir,
|
| 34 |
+
"realigned.sto") if output_sto_path is None else output_sto_path
|
| 35 |
+
with open(input_fasta_path, "r") as f:
|
| 36 |
+
hmm = self.hmmbuild_runner.build_rna_profile_from_fasta(f.read())
|
| 37 |
+
with open(hmm_output_path, 'w') as f:
|
| 38 |
+
f.write(hmm)
|
| 39 |
+
|
| 40 |
+
cmd = [
|
| 41 |
+
self.binary_path,
|
| 42 |
+
'--rna', # Don't include the alignment in stdout.
|
| 43 |
+
'--mapali', input_fasta_path,
|
| 44 |
+
"-o", output_sto_path,
|
| 45 |
+
hmm_output_path,
|
| 46 |
+
input_sto_path
|
| 47 |
+
]
|
| 48 |
+
# print(cmd)
|
| 49 |
+
# print(" ".join(cmd))
|
| 50 |
+
logging.info('Launching sub-process %s', cmd)
|
| 51 |
+
process = subprocess.Popen(
|
| 52 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
| 53 |
+
with utils.timing(
|
| 54 |
+
f'hmmsearch query'):
|
| 55 |
+
stdout, stderr = process.communicate()
|
| 56 |
+
retcode = process.wait()
|
| 57 |
+
|
| 58 |
+
if retcode:
|
| 59 |
+
raise RuntimeError(
|
| 60 |
+
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
|
| 61 |
+
stdout.decode('utf-8'), stderr.decode('utf-8')))
|
| 62 |
+
if delete_out:
|
| 63 |
+
with open(output_sto_path) as f:
|
| 64 |
+
out_msa = f.read()
|
| 65 |
+
if delete_out:
|
| 66 |
+
return out_msa
|
PhysDock/data/tools/hmmbuild.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
import subprocess
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
from . import utils
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Hmmbuild(object):
|
| 26 |
+
"""Python wrapper of the hmmbuild binary."""
|
| 27 |
+
|
| 28 |
+
def __init__(self,
|
| 29 |
+
*,
|
| 30 |
+
binary_path: str,
|
| 31 |
+
singlemx: bool = False):
|
| 32 |
+
"""Initializes the Python hmmbuild wrapper.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
binary_path: The path to the hmmbuild executable.
|
| 36 |
+
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
|
| 37 |
+
just use a common substitution score matrix.
|
| 38 |
+
|
| 39 |
+
Raises:
|
| 40 |
+
RuntimeError: If hmmbuild binary not found within the path.
|
| 41 |
+
"""
|
| 42 |
+
self.binary_path = binary_path
|
| 43 |
+
self.singlemx = singlemx
|
| 44 |
+
|
| 45 |
+
def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
|
| 46 |
+
"""Builds a HHM for the aligned sequences given as an A3M string.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
sto: A string with the aligned sequences in the Stockholm format.
|
| 50 |
+
model_construction: Whether to use reference annotation in the msa to
|
| 51 |
+
determine consensus columns ('hand') or default ('fast').
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
A string with the profile in the HMM format.
|
| 55 |
+
|
| 56 |
+
Raises:
|
| 57 |
+
RuntimeError: If hmmbuild fails.
|
| 58 |
+
"""
|
| 59 |
+
return self._build_profile(sto, model_construction=model_construction)
|
| 60 |
+
|
| 61 |
+
def build_profile_from_a3m(self, a3m: str) -> str:
|
| 62 |
+
"""Builds a HHM for the aligned sequences given as an A3M string.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
a3m: A string with the aligned sequences in the A3M format.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
A string with the profile in the HMM format.
|
| 69 |
+
|
| 70 |
+
Raises:
|
| 71 |
+
RuntimeError: If hmmbuild fails.
|
| 72 |
+
"""
|
| 73 |
+
lines = []
|
| 74 |
+
for line in a3m.splitlines():
|
| 75 |
+
if not line.startswith('>'):
|
| 76 |
+
line = re.sub('[a-z]+', '', line) # Remove inserted residues.
|
| 77 |
+
lines.append(line + '\n')
|
| 78 |
+
msa = ''.join(lines)
|
| 79 |
+
return self._build_profile(msa, model_construction='fast')
|
| 80 |
+
|
| 81 |
+
def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
|
| 82 |
+
"""Builds a HMM for the aligned sequences given as an MSA string.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
msa: A string with the aligned sequences, in A3M or STO format.
|
| 86 |
+
model_construction: Whether to use reference annotation in the msa to
|
| 87 |
+
determine consensus columns ('hand') or default ('fast').
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
A string with the profile in the HMM format.
|
| 91 |
+
|
| 92 |
+
Raises:
|
| 93 |
+
RuntimeError: If hmmbuild fails.
|
| 94 |
+
ValueError: If unspecified arguments are provided.
|
| 95 |
+
"""
|
| 96 |
+
if model_construction not in {'hand', 'fast'}:
|
| 97 |
+
raise ValueError(f'Invalid model_construction {model_construction} - only'
|
| 98 |
+
'hand and fast supported.')
|
| 99 |
+
|
| 100 |
+
with utils.tmpdir_manager() as query_tmp_dir:
|
| 101 |
+
input_query = os.path.join(query_tmp_dir, 'query.msa')
|
| 102 |
+
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
|
| 103 |
+
|
| 104 |
+
with open(input_query, 'w') as f:
|
| 105 |
+
f.write(msa)
|
| 106 |
+
|
| 107 |
+
cmd = [self.binary_path]
|
| 108 |
+
# If adding flags, we have to do so before the output and input:
|
| 109 |
+
|
| 110 |
+
if model_construction == 'hand':
|
| 111 |
+
cmd.append(f'--{model_construction}')
|
| 112 |
+
if self.singlemx:
|
| 113 |
+
cmd.append('--singlemx')
|
| 114 |
+
cmd.extend([
|
| 115 |
+
'--amino',
|
| 116 |
+
output_hmm_path,
|
| 117 |
+
input_query,
|
| 118 |
+
])
|
| 119 |
+
|
| 120 |
+
logging.info('Launching subprocess %s', cmd)
|
| 121 |
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
|
| 122 |
+
stderr=subprocess.PIPE)
|
| 123 |
+
|
| 124 |
+
with utils.timing('hmmbuild query'):
|
| 125 |
+
stdout, stderr = process.communicate()
|
| 126 |
+
retcode = process.wait()
|
| 127 |
+
logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
|
| 128 |
+
stdout.decode('utf-8'), stderr.decode('utf-8'))
|
| 129 |
+
|
| 130 |
+
if retcode:
|
| 131 |
+
raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
|
| 132 |
+
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
|
| 133 |
+
|
| 134 |
+
with open(output_hmm_path, encoding='utf-8') as f:
|
| 135 |
+
hmm = f.read()
|
| 136 |
+
|
| 137 |
+
return hmm
|
| 138 |
+
|
| 139 |
+
def build_rna_profile_from_fasta(self, fasta: str):
|
| 140 |
+
with utils.tmpdir_manager() as query_tmp_dir:
|
| 141 |
+
input_query = os.path.join(query_tmp_dir, 'query.fasta')
|
| 142 |
+
output_hmm_path = os.path.join(query_tmp_dir, 'query.hmm')
|
| 143 |
+
with open(input_query, 'w') as f:
|
| 144 |
+
f.write(fasta)
|
| 145 |
+
cmd = [self.binary_path]
|
| 146 |
+
cmd.extend([
|
| 147 |
+
'--rna',
|
| 148 |
+
output_hmm_path,
|
| 149 |
+
input_query,
|
| 150 |
+
])
|
| 151 |
+
logging.info('Launching subprocess %s', cmd)
|
| 152 |
+
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
|
| 153 |
+
stderr=subprocess.PIPE)
|
| 154 |
+
with utils.timing('hmmbuild query'):
|
| 155 |
+
stdout, stderr = process.communicate()
|
| 156 |
+
retcode = process.wait()
|
| 157 |
+
logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
|
| 158 |
+
stdout.decode('utf-8'), stderr.decode('utf-8'))
|
| 159 |
+
if retcode:
|
| 160 |
+
raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
|
| 161 |
+
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
|
| 162 |
+
|
| 163 |
+
with open(output_hmm_path, encoding='utf-8') as f:
|
| 164 |
+
hmm = f.read()
|
| 165 |
+
return hmm
|
PhysDock/data/tools/hmmsearch.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import subprocess
|
| 19 |
+
from typing import Optional, Sequence
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
from . import parsers
|
| 23 |
+
from . import hmmbuild
|
| 24 |
+
from . import utils
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary.

    Builds an HMM profile from a query Stockholm MSA (via hmmbuild) and
    searches it against a FASTA sequence database, returning the hits as a
    Stockholm-formatted alignment string.
    """

    def __init__(self,
                 *,
                 binary_path: str,
                 hmmbuild_binary_path: str,
                 database_path: str,
                 flags: Optional[Sequence[str]] = None
                 ):
        """Initializes the Python hmmsearch wrapper.

        Args:
          binary_path: The path to the hmmsearch executable.
          hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
            an hmm from an input a3m.
          database_path: The path to the hmmsearch database (FASTA format).
          flags: List of flags to be used by hmmsearch.

        Raises:
          ValueError: If the hmmsearch database is not found at database_path.
        """
        self.binary_path = binary_path
        # hmmsearch consumes an HMM profile rather than a raw MSA, so keep an
        # hmmbuild runner around to convert query alignments first.
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings: deliberately permissive
            # pre-filters (--F1/F2/F3) and E-value thresholds so that even
            # distant hits are retained.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        # Results are returned as Stockholm-formatted alignments.
        return 'sto'

    @property
    def input_format(self) -> str:
        # Queries are expected as Stockholm-formatted alignments.
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa.

        Args:
          msa_sto: Query MSA in Stockholm format.
          output_dir: If given, the raw hmmsearch output file is kept here;
            otherwise it lives only in a temporary directory.

        Returns:
          The hmmsearch hits as a Stockholm-format string.
        """
        # 'hand' model construction — presumably uses the MSA's own reference
        # annotation instead of consensus columns; see hmmbuild docs to confirm.
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto,
            model_construction='hand'
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self,
                       hmm: str,
                       output_dir: Optional[str] = None
                       ) -> str:
        """Queries the database using hmmsearch using a given hmm.

        Args:
          hmm: The HMM profile text to search with.
          output_dir: If given, 'hmm_output.sto' is written there and persists;
            otherwise it is written to (and deleted with) the temp directory.

        Returns:
          The hmmsearch hits as a Stockholm-format string.

        Raises:
          RuntimeError: If the hmmsearch subprocess exits with a non-zero code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            # Fall back to the temporary directory when the caller does not
            # want to keep the raw output.
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            # Read the alignment back while the temp directory still exists.
            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
            output_string: str,
            input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool.

        Args:
          output_string: Raw Stockholm output produced by query/query_with_hmm.
          input_sequence: The original query sequence.

        Returns:
          A sequence of parsers.TemplateHit objects.
        """
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
|
PhysDock/data/tools/jackhmmer.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run Jackhmmer from Python."""
|
| 17 |
+
|
| 18 |
+
from concurrent import futures
|
| 19 |
+
import glob
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import subprocess
|
| 23 |
+
from typing import Any, Callable, Mapping, Optional, Sequence
|
| 24 |
+
from urllib import request
|
| 25 |
+
|
| 26 |
+
from . import parsers
|
| 27 |
+
from . import utils
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary.

    Runs iterative profile searches of query FASTA files against a sequence
    database, either from a local database file or by streaming numbered
    database chunks from a remote URL.
    """

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        seq_limit: int = 50000,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
          binary_path: The path to the jackhmmer executable.
          database_path: The path to the jackhmmer database (FASTA format).
          n_cpu: The number of CPUs to give Jackhmmer.
          n_iter: The number of Jackhmmer iterations.
          e_value: The E-value, see Jackhmmer docs for more details.
          z_value: The Z-value, see Jackhmmer docs for more details.
          get_tblout: Whether to save tblout string.
          filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
          filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
          filter_f3: Forward pre-filter, set to >1.0 to turn off.
          seq_limit: Stored but currently unused — the matching command-line
            flag is commented out in _query_chunk.
          incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
            round.
          dom_e: Domain e-value criteria for inclusion in tblout.
          num_streamed_chunks: Number of database chunks to stream over.
          streaming_callback: Callback function run after each chunk iteration with
            the iteration number as argument.

        Raises:
          ValueError: If database_path does not exist and chunk streaming is
            not enabled.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # A missing local database is only fatal when we are not streaming
        # chunks from a remote location.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.seq_limit = seq_limit
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self,
        input_fasta_path: str,
        database_path: str,
        max_sequences: Optional[int] = None
    ) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer.

        Args:
          input_fasta_path: Path to the query FASTA file.
          database_path: Path to the database (or database chunk) to search.
          max_sequences: If given, truncate the resulting Stockholm MSA to at
            most this many sequences.

        Returns:
          A dict with keys 'sto' (Stockholm MSA string), 'tbl' (tblout text,
          empty unless get_tblout is set), 'stderr', 'n_iter' and 'e_value'.

        Raises:
          RuntimeError: If Jackhmmer exits with a non-zero return code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expensive of sensitivity. They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                "--F1",
                str(self.filter_f1),
                "--F2",
                str(self.filter_f2),
                "--F3",
                str(self.filter_f3),
                # "--seq_limit",
                # str(self.seq_limit),
                "--incE",
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--cpu",
                str(self.n_cpu),
                "-N",
                str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )
            # print(cmd)
            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Truncation is delegated to the parser so the full alignment does
            # not have to be loaded when only a prefix is wanted.
            if (max_sequences is None):
                with open(sto_path) as f:
                    sto = f.read()
            else:
                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )

        return raw_output

    def query(self,
              input_fasta_path: str,
              max_sequences: Optional[int] = None
              ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database for a single FASTA file; see query_multiple."""
        return self.query_multiple([input_fasta_path], max_sequences)

    def query_multiple(self,
                       input_fasta_paths: Sequence[str],
                       max_sequences: Optional[int] = None
                       ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database using Jackhmmer.

        Args:
          input_fasta_paths: One query FASTA path per search to run.
          max_sequences: Optional cap on sequences in each returned MSA.

        Returns:
          In the streamed case, one list of per-chunk result dicts per input
          FASTA. NOTE(review): in the non-streamed case this returns a flat
          list of result dicts (one per input), not a list of lists — callers
          should be checked against the declared return type.
        """
        # Non-streamed case: search the single local database directly.
        if self.num_streamed_chunks is None:
            single_chunk_results = []
            for input_fasta_path in input_fasta_paths:
                single_chunk_result = self._query_chunk(
                    input_fasta_path, self.database_path, max_sequences,
                )
                single_chunk_results.append(single_chunk_result)
            return single_chunk_results

        db_basename = os.path.basename(self.database_path)
        # Streamed case: database_path is treated as a URL prefix and chunk i
        # lives at '<database_path>.<i>'; chunks are staged on a ramdisk.
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                future.result()
                for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
                    chunked_outputs[fasta_idx].append(
                        self._query_chunk(
                            input_fasta_path,
                            db_local_chunk(i),
                            max_sequences
                        )
                    )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works
                # even for databases with only 1 chunk
                if (i < self.num_streamed_chunks):
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_outputs
|
PhysDock/data/tools/kalign.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""A Python wrapper for Kalign."""
|
| 17 |
+
import os
|
| 18 |
+
import subprocess
|
| 19 |
+
from typing import Sequence
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
from . import utils
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _to_a3m(sequences: Sequence[str]) -> str:
|
| 26 |
+
"""Converts sequences to an a3m file."""
|
| 27 |
+
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
|
| 28 |
+
a3m = []
|
| 29 |
+
for sequence, name in zip(sequences, names):
|
| 30 |
+
a3m.append(u">" + name + u"\n")
|
| 31 |
+
a3m.append(sequence + u"\n")
|
| 32 |
+
return "".join(a3m)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
          binary_path: The path to the Kalign binary.

        Raises:
          RuntimeError: If Kalign binary not found within the path.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
          sequences: A list of query sequence strings. The sequences have to be at
            least 6 residues long (Kalign requires this). Note that the order in
            which you give the sequences might alter the output slightly as
            different alignment tree might get constructed.

        Returns:
          A string with the alignment in a3m format.

        Raises:
          RuntimeError: If Kalign fails.
          ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Reject short sequences up front — Kalign itself cannot handle them.
        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (s, len(s))
                )

        with utils.tmpdir_manager() as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            # Kalign reads from a file, so serialize the sequences to FASTA.
            with open(input_fasta_path, "w") as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-o",
                output_a3m_path,
                "-format",
                "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout.decode("utf-8"),
                    stderr.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            # Read the result before the temporary directory is removed.
            with open(output_a3m_path) as f:
                a3m = f.read()

            return a3m
|
PhysDock/data/tools/mmcif_parsing.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Parses the mmCIF file format."""
|
| 17 |
+
import collections
|
| 18 |
+
import dataclasses
|
| 19 |
+
import functools
|
| 20 |
+
import io
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
from typing import Any, Mapping, Optional, Sequence, Tuple
|
| 25 |
+
import numpy as np
|
| 26 |
+
from Bio import PDB
|
| 27 |
+
|
| 28 |
+
from . import PDBData
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
class residue_constants:
    """Minimal local stand-in for a residue-constants module.

    Carries only the 37-entry heavy-atom name vocabulary used for atom37-style
    protein atom indexing. All members are class-level constants; instances
    are never needed.
    """
    # Number of distinct atom names (length of atom_types).
    atom_type_num = 37
    atom_types = ["N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD", "CD1", "CD2", "ND1", "ND2",
                  "OD1",
                  "OD2", "SD", "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH",
                  "CZ",
                  "CZ2", "CZ3", "NZ", "OXT", ]
    # Maps each atom name to its fixed index in the atom37 layout.
    atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
"""General-purpose errors used throughout the data pipeline"""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Error(Exception):
    """Base class for exceptions raised by this module."""
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID."""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]  # Header metadata extracted from the mmCIF.
PdbStructure = PDB.Structure.Structure  # Biopython structure object.
SeqRes = str  # One-letter-code amino acid sequence string.
MmCIFDict = Mapping[str, Sequence[str]]  # Raw _mmcif_dict: field name -> column values.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclasses.dataclass(frozen=True)
class Monomer:
    """A (residue id, sequence number) pair from an mmCIF polymer sequence."""
    id: str   # Residue / chemical-component identifier.
    num: int  # Sequence position (mmCIF internal numbering).
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Note - mmCIF format provides no guarantees on the type of author-assigned
|
| 67 |
+
# sequence numbers. They need not be integers.
|
| 68 |
+
@dataclasses.dataclass(frozen=True)
class AtomSite:
    """One record (atom row) taken from the mmCIF _atom_site loop.

    Values are kept as read from the file; in particular author-assigned
    sequence numbers are strings because mmCIF does not guarantee integers.
    """
    residue_name: str
    author_chain_id: str   # Chain id as assigned by the structure authors.
    mmcif_chain_id: str    # Internal mmCIF chain id.
    author_seq_num: str    # Author residue number; kept as a string (see above).
    mmcif_seq_num: int     # Internal mmCIF residue number.
    insertion_code: str
    hetatm_atom: str       # ATOM/HETATM group flag — presumably; confirm at parse site.
    model_num: int
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Used to map SEQRES index to a residue in the structure.
|
| 81 |
+
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    """Location of a residue within the parsed structure."""
    chain_id: str
    residue_number: int
    insertion_code: str
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    """A residue entry for one SEQRES index, resolved or missing."""
    position: Optional[ResiduePosition]  # None presumably when no coordinates exist — confirm.
    name: str
    is_missing: bool  # True when the residue is absent from the structure.
    hetflag: str
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Immutable (frozen dataclass); constructed once per parsed file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure.
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
        {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                          1: ResidueAtPosition,
                                                          ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    raw_string: Any
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be successfully
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# NOTE(review): unlike MultipleChainsError, this derives from Exception
# directly rather than the module's Error base class.
class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def mmcif_loop_to_list(
    prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
    """
    # Column names belonging to this loop, plus their value lists, in the
    # order the parser produced them.
    cols = [key for key in parsed_info if key.startswith(prefix)]
    data = [parsed_info[key] for key in cols]

    # Every column of a loop must carry the same number of rows.
    assert all([len(xs) == len(data[0]) for xs in data]), (
        "mmCIF error: Not all loops are the same length: %s" % cols
    )

    # Transpose columns into per-row dicts.
    return [dict(zip(cols, row)) for row in zip(*data)]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      index: Which item of loop data should serve as the key.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
      indexed by the index column.
    """
    # Re-key the row dicts from mmcif_loop_to_list by the chosen index column;
    # later rows with a duplicate index value overwrite earlier ones.
    return {
        row[index]: row
        for row in mmcif_loop_to_list(prefix, parsed_info)
    }
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@functools.lru_cache(16, typed=False)
def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Results are memoized (LRU cache of 16 entries keyed on the arguments).

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        # Only the first model is retained for downstream use.
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        # the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name
                # Normalize the mmCIF 'unset' markers ('.'/'?') to a blank icode.
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                # 0-based index of this residue within its entity sequence.
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        # Build the one-letter sequence string per author chain; 3-letter codes
        # without a single-letter mapping become "X".
        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = PDBData.protein_letters_3to1_extended.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        # Record the exception object keyed by (file_id, chain); optionally
        # re-raise when the caller wants failures to propagate.
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    model_iter = structure.get_models()
    return next(model_iter)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# NOTE(review): not referenced anywhere in this visible chunk; presumably a
# threshold for classifying short chains as peptides — confirm against the
# rest of the module before relying on it.
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest revision date recorded in the mmCIF audit history."""
    dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
    return min(dates)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution."""
    header = {}

    # Join every experimental method recorded for this entry.
    methods = [
        experiment["_exptl.method"].lower()
        for experiment in mmcif_loop_to_list("_exptl.", parsed_info)
    ]
    header["structure_method"] = ",".join(methods)

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    # Try each resolution field in turn; a later parseable field overwrites an
    # earlier one, and 0.00 remains if none parses.
    header["resolution"] = 0.00
    for res_key in (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    ):
        if res_key not in parsed_info:
            continue
        try:
            header["resolution"] = float(parsed_info[res_key][0])
        except ValueError:
            logging.debug(
                "Invalid resolution format: %s", parsed_info[res_key]
            )

    return header
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Returns list of atom sites; contains data not present in the structure."""
    # Column order must match the AtomSite field order exactly.
    columns = (
        "_atom_site.label_comp_id",
        "_atom_site.auth_asym_id",
        "_atom_site.label_asym_id",
        "_atom_site.auth_seq_id",
        "_atom_site.label_seq_id",
        "_atom_site.pdbx_PDB_ins_code",
        "_atom_site.group_PDB",
        "_atom_site.pdbx_PDB_model_num",
    )
    rows = zip(*(parsed_info[column] for column in columns))
    return [AtomSite(*row) for row in rows]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

    # Group the per-residue loop rows by entity id, preserving file order.
    polymers = collections.defaultdict(list)
    for entity_poly_seq in entity_poly_seqs:
        polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
            Monomer(
                id=entity_poly_seq["_entity_poly_seq.mon_id"],
                num=int(entity_poly_seq["_entity_poly_seq.num"]),
            )
        )

    # Get chemical compositions. Will allow us to identify which of these polymers
    # are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Get chains information for each entity. Necessary so that we can return a
    # dict keyed on chain id rather than entity.
    struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)

    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in struct_asyms:
        chain_id = struct_asym["_struct_asym.id"]
        entity_id = struct_asym["_struct_asym.entity_id"]
        entity_to_mmcif_chains[entity_id].append(chain_id)

    # Identify and return the valid protein chains.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        chain_ids = entity_to_mmcif_chains[entity_id]

        # Reject polymers without any peptide-like components, such as DNA/RNA.
        # One peptide-like monomer is enough to keep the whole entity.
        if any(
            [
                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
                for monomer in seq_info
            ]
        ):
            for chain_id in chain_ids:
                valid_chains[chain_id] = seq_info
    return valid_chains
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def _is_set(data: str) -> bool:
|
| 448 |
+
"""Returns False if data is a special mmCIF character indicating 'unset'."""
|
| 449 |
+
return data not in (".", "?")
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts fixed-size atom coordinates and an existence mask for one chain.

    Args:
      mmcif_object: Parsed mmCIF object with structure and residue mappings.
      chain_id: Author chain id to extract.
      _zero_center_positions: If True, subtract the mean of all resolved atom
        positions so the chain is centred at the origin.

    Returns:
      A tuple (positions, mask): positions has shape
      [num_res, residue_constants.atom_type_num, 3] and mask has shape
      [num_res, residue_constants.atom_type_num]; missing residues/atoms are
      zero-filled with mask 0.

    Raises:
      MultipleChainsError: If the structure does not contain exactly one chain
        with the requested id.
    """
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            # Fetch the Biopython residue by its (hetflag, resseq, icode) id.
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
                    mask[residue_constants.atom_order["SD"]] = 1.0

            # Fix naming errors in arginine residues where NH2 is incorrectly
            # assigned to be closer to CD than NH1
            cd = residue_constants.atom_order['CD']
            nh1 = residue_constants.atom_order['NH1']
            nh2 = residue_constants.atom_order['NH2']
            if (
                res.get_resname() == 'ARG' and
                all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
                (np.linalg.norm(pos[nh1] - pos[cd]) >
                 np.linalg.norm(pos[nh2] - pos[cd]))
            ):
                # Swap both coordinates and masks so the pair stays consistent.
                pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
                mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()

        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    if _zero_center_positions:
        # Centre only resolved atoms; zero-padded entries stay untouched.
        binary_mask = all_atom_mask.astype(bool)
        translation_vec = all_atom_positions[binary_mask].mean(axis=0)
        all_atom_positions[binary_mask] -= translation_vec

    return all_atom_positions, all_atom_mask
|
PhysDock/data/tools/msa_identifiers.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Utilities for extracting identifiers from MSA sequence descriptions."""
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
import re
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
# Named groups: AccessionIdentifier (primary accession) and SpeciesIdentifier
# (mnemonic species code).
_UNIPROT_PATTERN = re.compile(
    r"""
    ^
    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
    (?:tr|sp)
    \|
    # A primary accession number of the UniProtKB entry.
    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
    (?:_\d)?
    \|
    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
    # protein ID code.
    (?:[A-Za-z0-9]+)
    _
    # A mnemonic species identification code.
    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
    # Small BFD uses a final value after an underscore, which we ignore.
    (?:_\d+)?
    $
    """,
    re.VERBOSE,
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclasses.dataclass(frozen=True)
class Identifiers:
    """Identifiers extracted from one MSA sequence description."""

    # Mnemonic species code parsed from a UniProt entry name (the
    # SpeciesIdentifier regex group); empty string when nothing was parsed.
    species_id: str = ""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Gets the species from an msa sequence identifier.

    The sequence identifier is expected in the UniProtKB
    `db|UniqueIdentifier|EntryName` format matched by _UNIPROT_PATTERN,
    e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`.

    Args:
      msa_sequence_identifier: a sequence identifier.

    Returns:
      An `Identifiers` instance with a species_id. This can be empty in the
      case where no identifier was found.
    """
    match = _UNIPROT_PATTERN.search(msa_sequence_identifier.strip())
    if match is None:
        return Identifiers()
    return Identifiers(species_id=match.group("SpeciesIdentifier"))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _extract_sequence_identifier(description: str) -> Optional[str]:
|
| 76 |
+
"""Extracts sequence identifier from description. Returns None if no match."""
|
| 77 |
+
split_description = description.split()
|
| 78 |
+
if split_description:
|
| 79 |
+
return split_description[0].partition("/")[0]
|
| 80 |
+
else:
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_identifiers(description: str) -> Identifiers:
    """Computes extra MSA features from the description."""
    sequence_identifier = _extract_sequence_identifier(description)
    return (Identifiers() if sequence_identifier is None
            else _parse_sequence_identifier(sequence_identifier))
|
PhysDock/data/tools/msa_pairing.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Pairing logic for multimer data pipeline."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import functools
|
| 19 |
+
import string
|
| 20 |
+
from typing import Any, Dict, Iterable, List, Sequence, Mapping
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import scipy.linalg
|
| 24 |
+
from scipy.linalg import block_diag
|
| 25 |
+
|
| 26 |
+
# TODO: This stuff should probably also be in a config

# Token index used for MSA gaps (also hard-coded as 31 in _make_msa_df).
MSA_GAP_IDX = 31
# NOTE(review): the two cutoffs below are not referenced in this visible
# chunk — confirm their use elsewhere in the module.
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MAX_MSA_SIZE = 16384

# Per-feature value used for the extra 'padding' row appended by
# pad_features; row index -1 selects it for chains without a pairing.
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
                  'msa_mask_all_seq': 1,
                  'deletion_matrix_all_seq': 0,
                  'deletion_matrix_int_all_seq': 0,
                  'msa': MSA_GAP_IDX,
                  'msa_mask': 1,
                  'deletion_matrix': 0,
                  'deletion_matrix_int': 0}

# Feature-name groups used to decide how features are merged across chains.
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
                'all_atom_mask', 'seq_mask', 'between_segment_residues',
                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
                'sym_id', 'entity_mask', 'deletion_mean',
                'prediction_atom_mask',
                'literature_positions', 'atom_indices_to_group_indices',
                'rigid_group_default_frame',
                'restype', 'token_index', 'token_exists', 's_mask',
                "token_id_to_centre_atom_id",
                "token_id_to_pseudo_beta_atom_id",
                "token_id_to_chunk_sizes",
                "token_id_to_conformer_id",
                "is_protein",
                "is_dna",
                "is_rna",
                "is_ligand",
                "atom_index",
                "atom_id_to_token_id",
                "ref_space_uid",
                "ref_pos",
                "ref_feat",
                "x_gt",
                "x_exists",
                "b_factors",
                "a_mask"

                )
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
                     'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def create_paired_features(
    chains: Iterable[Mapping[str, np.ndarray]],
) -> List[Mapping[str, np.ndarray]]:
    """Returns the original chains with paired NUM_SEQ features.

    Args:
      chains: A list of feature dictionaries for each chain,
        i.e. list(all_chain_features.values()).

    Returns:
      A list of feature dictionaries with sequence features including only
      rows to be paired.
    """
    chains = list(chains)
    chain_keys = chains[0].keys()

    # A single chain has nothing to pair against.
    if len(chains) < 2:
        return chains
    else:

        updated_chains = []
        # Map: number of chains covered by a pairing -> paired row indices.
        paired_chains_to_paired_row_indices = pair_sequences(chains)

        paired_rows = reorder_paired_rows(
            paired_chains_to_paired_row_indices)

        for chain_num, chain in enumerate(chains):
            # Copy non-paired features untouched; rebuild *_all_seq features
            # from the selected paired rows.
            new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
            for feature_name in chain_keys:
                if feature_name.endswith('_all_seq'):
                    # Pad with one extra row so index -1 selects the pad row
                    # for chains lacking a pairing for a given species.
                    feats_padded = pad_features(chain[feature_name], feature_name)

                    new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
            new_chain['num_alignments_all_seq'] = np.asarray(
                len(paired_rows[:, chain_num]))

            updated_chains.append(new_chain)

        return updated_chains
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
    """Add a 'padding' row at the end of the features list.

    The padding row will be selected as a 'paired' row in the case of partial
    alignment - for the chain that doesn't have paired alignment.

    Args:
      feature: The feature to be padded.
      feature_name: The name of the feature to be padded.

    Returns:
      The feature with an additional padding row (or the feature unchanged for
      feature names that need no padding).
    """
    # np.string_ was removed in NumPy 2.0; np.bytes_ is the equivalent alias
    # on both NumPy 1.x and 2.x.
    assert feature.dtype != np.dtype(np.bytes_)
    if feature_name in ('msa_all_seq',
                        'deletion_matrix_all_seq'):
        # MSA-shaped features get one full row of the per-feature pad value.
        num_res = feature.shape[1]
        padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                         feature.dtype)
    elif feature_name == 'msa_species_identifiers_all_seq':
        # Species identifiers are padded with an empty bytes entry.
        padding = [b'']
    else:
        return feature
    feats_padded = np.concatenate([feature, padding], axis=0)
    return feats_padded
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
|
| 144 |
+
"""Makes dataframe with msa features needed for msa pairing."""
|
| 145 |
+
chain_msa = chain_features['msa_all_seq']
|
| 146 |
+
query_seq = chain_msa[0]
|
| 147 |
+
|
| 148 |
+
per_seq_similarity = np.sum(
|
| 149 |
+
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
|
| 150 |
+
per_seq_gap = np.sum(chain_msa == 31, axis=-1) / float(len(query_seq))
|
| 151 |
+
msa_df = pd.DataFrame({
|
| 152 |
+
'msa_species_identifiers':
|
| 153 |
+
chain_features['msa_species_identifiers_all_seq'],
|
| 154 |
+
'msa_row':
|
| 155 |
+
np.arange(len(
|
| 156 |
+
chain_features['msa_species_identifiers_all_seq'])),
|
| 157 |
+
'msa_similarity': per_seq_similarity,
|
| 158 |
+
'gap': per_seq_gap
|
| 159 |
+
})
|
| 160 |
+
return msa_df
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
|
| 164 |
+
"""Creates mapping from species to msa dataframe of that species."""
|
| 165 |
+
species_lookup = {}
|
| 166 |
+
for species, species_df in msa_df.groupby('msa_species_identifiers'):
|
| 167 |
+
species_lookup[species] = species_df
|
| 168 |
+
return species_lookup
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
|
| 172 |
+
) -> List[List[int]]:
|
| 173 |
+
"""Finds MSA sequence pairings across chains based on sequence similarity.
|
| 174 |
+
|
| 175 |
+
Each chain's MSA sequences are first sorted by their sequence similarity to
|
| 176 |
+
their respective target sequence. The sequences are then paired, starting
|
| 177 |
+
from the sequences most similar to their target sequence.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
this_species_msa_dfs: a list of dataframes containing MSA features for
|
| 181 |
+
sequences for a specific species.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
A list of lists, each containing M indices corresponding to paired MSA rows,
|
| 185 |
+
where M is the number of chains.
|
| 186 |
+
"""
|
| 187 |
+
all_paired_msa_rows = []
|
| 188 |
+
|
| 189 |
+
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
|
| 190 |
+
if species_df is not None]
|
| 191 |
+
take_num_seqs = np.min(num_seqs)
|
| 192 |
+
|
| 193 |
+
sort_by_similarity = (
|
| 194 |
+
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
|
| 195 |
+
|
| 196 |
+
for species_df in this_species_msa_dfs:
|
| 197 |
+
if species_df is not None:
|
| 198 |
+
species_df_sorted = sort_by_similarity(species_df)
|
| 199 |
+
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
|
| 200 |
+
else:
|
| 201 |
+
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
|
| 202 |
+
all_paired_msa_rows.append(msa_rows)
|
| 203 |
+
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
|
| 204 |
+
return all_paired_msa_rows
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def pair_sequences(
    examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains.

    Args:
      examples: Per-chain feature dicts containing the *_all_seq MSA features.

    Returns:
      A dict mapping the number of chains covered by a pairing to an array of
      per-chain MSA row indices; -1 denotes the trailing pad row for chains
      without a match.
    """

    num_examples = len(examples)

    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        # Per-row species / similarity / gap statistics for this chain's MSA.
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    common_species = sorted(common_species)
    common_species.remove(b'')  # Remove target sequence species.

    # Row 0 of every chain (the query sequences) is always paired together.
    all_paired_msa_rows = [np.zeros(len(examples), int)]

    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        # Skip species with very large MSA groups to bound pairing cost.
        if np.any(
            np.array([len(species_df) for species_df in
                      this_species_msa_dfs if
                      isinstance(species_df, pd.DataFrame)]) > 600):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
    # Convert each bucket's list of row tuples to a single 2-D array.
    all_paired_msa_rows_dict = {
        num_examples: np.array(paired_msa_rows) for
        num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
                        ) -> np.ndarray:
    """Creates a list of indices of paired MSA rows across chains.

    Args:
      all_paired_msa_rows_dict: a mapping from the number of paired chains to
        the paired indices.

    Returns:
      An array of paired-row index tuples, ordered by:
        1) the number of chains in the paired alignment (all-chain pairings
           come first), and
        2) the product of the row indices (a proxy for e-value rank).
    """
    ordered_rows = []
    # Walk pairing counts from most chains paired down to fewest.
    for chain_count in sorted(all_paired_msa_rows_dict, reverse=True):
        index_block = all_paired_msa_rows_dict[chain_count]
        # Rows with small index products correspond to higher-ranked hits.
        sort_keys = np.abs(np.array([np.prod(row) for row in index_block]))
        for position in np.argsort(sort_keys):
            ordered_rows.append(index_block[position])
    return np.array(ordered_rows)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
|
| 289 |
+
# """Like scipy.linalg.block_diag but with an optional padding value."""
|
| 290 |
+
# ones_arrs = [np.ones_like(x) for x in arrs]
|
| 291 |
+
# off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
|
| 292 |
+
# diag = scipy.linalg.block_diag(*arrs)
|
| 293 |
+
# diag += (off_diag_mask * pad_value).astype(diag.dtype)
|
| 294 |
+
# return diag
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def _correct_post_merged_feats(
        np_example: Mapping[str, np.ndarray],
        np_chains_list: Sequence[Mapping[str, np.ndarray]],
        pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Adds features that need to be computed/recomputed post merging.

    Mutates and returns `np_example`. `np_chains_list` and
    `pair_msa_sequences` are accepted for interface compatibility but unused.
    """
    # 'sequence_3' is a "-"-joined string; the number of "-"-separated
    # segments is stored as 'seq_length'.
    # NOTE(review): this counts segments, not residues — confirm downstream
    # consumers expect that.
    segment_count = len(np_example['sequence_3'].split("-"))
    np_example['seq_length'] = np.asarray(segment_count, dtype=np.int32)

    # The alignment count is simply the merged MSA's row count.
    row_count = np_example['msa'].shape[0]
    np_example['num_alignments'] = np.asarray(row_count, dtype=np.int32)

    return np_example
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _pad_templates(chains: Sequence[Mapping[str, np.ndarray]],
                   max_templates: int) -> Sequence[Mapping[str, np.ndarray]]:
    """For each chain pad the number of templates to a fixed size.

    Args:
      chains: A list of protein chains (mutated in place).
      max_templates: Each chain will be padded to have this many templates.

    Returns:
      The list of chains, with every TEMPLATE_FEATURES entry zero-padded
      along axis 0 up to max_templates.
    """
    for chain in chains:
        for feat_name, value in chain.items():
            if feat_name not in TEMPLATE_FEATURES:
                continue
            # Pad only the leading (template) dimension; later axes unchanged.
            pad_widths = [(0, 0)] * value.ndim
            pad_widths[0] = (0, max_templates - value.shape[0])
            chain[feat_name] = np.pad(value, pad_widths, mode='constant')
    return chains
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _merge_features_from_multiple_chains(
        chains: Sequence[Mapping[str, np.ndarray]],
        pair_msa_sequences: bool) -> Mapping[str, np.ndarray]:
    """Merge features from multiple chains.

    Args:
      chains: A list of feature dictionaries that we want to merge.
      pair_msa_sequences: Whether to concatenate MSA features along the
        num_res dimension (if True), or to block diagonalize them (if False).

    Returns:
      A feature dictionary for the merged example.
    """
    merged_example = {}

    for feature_name in chains[0]:
        # These per-chain bookkeeping features are dropped; counts are
        # recomputed after merging.
        if feature_name == "msa_species_identifiers" or feature_name == "num_alignments_all_seq" or feature_name == "num_alignments":
            continue
        feats = [x[feature_name] for x in chains]

        # Strip the optional '_all_seq' suffix so paired and unpaired
        # variants of a feature are routed identically.
        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split == "templ_feat":
            # Pairwise template feature: place each chain's block on the
            # diagonal of a [num_templ, total_len, total_len, 108] tensor,
            # leaving inter-chain regions zero.
            num_templ = feats[0].shape[0]
            total_len = sum(feat.shape[1] for feat in feats)
            out_mat = np.zeros([num_templ, total_len, total_len, 108], dtype=np.float32)
            start = 0
            end = 0
            for feat in feats:
                end += feat.shape[1]
                out_mat[:, start:end, start:end] = feat
                start = end
            merged_example[feature_name] = out_mat
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            # BUG FIX: the original called np.sum(x for x in feats) — passing
            # a generator to np.sum is deprecated (and rejected by modern
            # NumPy); historically it silently fell back to the builtin sum.
            # Use the builtin sum explicitly for identical elementwise
            # behavior, then coerce to int32 as before.
            merged_example[feature_name] = np.asarray(sum(feats)).astype(np.int32)
        else:
            # Features assumed identical across chains (e.g. sequence-level
            # metadata): keep the first chain's copy.
            merged_example[feature_name] = feats[0]

    return merged_example
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _merge_homomers_dense_msa(
        chains: Iterable[Mapping[str, np.ndarray]]) -> Sequence[Mapping[str, np.ndarray]]:
    """Merge all identical chains, making the resulting MSA dense.

    Chains sharing the same entity_id are merged into one feature dict whose
    MSA features are concatenated along the num_res dimension.
    """
    # Bucket chains by entity id.
    by_entity = collections.defaultdict(list)
    for chain_feats in chains:
        by_entity[chain_feats['entity_id'][0]].append(chain_feats)

    # Deterministic output order: iterate entities sorted by id, merging each
    # group with paired (dense) MSA concatenation.
    merged = []
    for entity_id in sorted(by_entity):
        merged.append(
            _merge_features_from_multiple_chains(
                by_entity[entity_id], pair_msa_sequences=True))
    return merged
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def _concatenate_paired_and_unpaired_features(
        example: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
    """Merges paired and block-diagonalised features.

    For every MSA feature present in both paired ('_all_seq') and unpaired
    form, stacks the paired rows on top of the unpaired ones and removes the
    '_all_seq' copy. Mutates and returns `example`.
    """
    for name in MSA_FEATURES:
        paired_name = name + '_all_seq'
        if name not in example or paired_name not in example:
            continue
        # Paired rows first so the cross-chain alignment stays at the top.
        example[name] = np.concatenate(
            [example[paired_name], example[name]], axis=0)
        example.pop(paired_name, None)

    return example
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def merge_chain_features(np_chains_list: List[Mapping[str, np.ndarray]],
                         pair_msa_sequences: bool,
                         max_templates: int) -> Mapping[str, np.ndarray]:
    """Merges features for multiple chains to single FeatureDict.

    Args:
      np_chains_list: List of FeatureDicts for each chain.
      pair_msa_sequences: Whether to merge paired MSAs.
      max_templates: The maximum number of templates to include.

    Returns:
      Single FeatureDict for entire complex.
    """
    # 1) Equalize template counts across chains.
    padded = _pad_templates(np_chains_list, max_templates=max_templates)

    # 2) Collapse identical (same-entity) chains into dense per-entity dicts.
    grouped = _merge_homomers_dense_msa(padded)

    # 3) Stack paired MSA rows on top of unpaired rows within each entity.
    combined = [
        _concatenate_paired_and_unpaired_features(entity_feats)
        for entity_feats in grouped
    ]

    assert len(combined) > 0, f"np_chain_list, error, {combined}"

    # 4) Merge the per-entity dicts into one complex-level example
    #    (block-diagonal MSA layout across entities).
    merged = _merge_features_from_multiple_chains(
        combined, pair_msa_sequences=False)

    # 5) Recompute features that are only valid after merging.
    return _correct_post_merged_feats(
        np_example=merged,
        np_chains_list=combined,
        pair_msa_sequences=pair_msa_sequences)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def deduplicate_unpaired_sequences(
        np_chains: List[Mapping[str, np.ndarray]]) -> List[Mapping[str, np.ndarray]]:
    """Removes unpaired sequences which duplicate a paired sequence.

    For each chain, any row of the unpaired MSA ('msa') whose sequence also
    occurs in the paired MSA ('msa_all_seq') is dropped from every MSA
    feature, and 'num_alignments' is refreshed from the filtered 'msa'.

    Args:
      np_chains: Per-chain feature dicts; mutated in place.

    Returns:
      The same list of (mutated) chain feature dicts.
    """
    feature_names = np_chains[0].keys()
    msa_features = MSA_FEATURES

    for chain in np_chains:
        # Convert the msa_all_seq rows to tuples so they are hashable.
        sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
        # Keep only unpaired rows whose sequence is not already paired.
        keep_rows = [
            row_num for row_num, seq in enumerate(chain['msa'])
            if tuple(seq) not in sequence_set
        ]
        # BUG FIX: the original guarded the filtering with
        # `if keep_rows is not None:` — always true for a list, so the
        # tautological check is removed. Behavior is unchanged, including the
        # keep_rows == [] case (all MSA features become empty).
        for feature_name in feature_names:
            if feature_name in msa_features:
                chain[feature_name] = chain[feature_name][keep_rows]
        chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)

    return np_chains
|
PhysDock/data/tools/nhmmer.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run Jackhmmer from Python."""
|
| 17 |
+
|
| 18 |
+
from concurrent import futures
|
| 19 |
+
import glob
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import subprocess
|
| 23 |
+
from typing import Any, Callable, Mapping, Optional, Sequence
|
| 24 |
+
from urllib import request
|
| 25 |
+
|
| 26 |
+
from . import parsers
|
| 27 |
+
from . import utils
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Nhmmer:
    """Python wrapper of the nhmmer binary (HMMER nucleotide search).

    NOTE(review): the original docstrings said "Jackhmmer"; the flags built
    below (--rna, --watson, --F3) and the binary name indicate nhmmer.
    """

    def __init__(
            self,
            *,
            binary_path: str,
            database_path: str,
            n_cpu: int = 8,  # passed as --cpu
            e_value: float = 0.001,  # -E / --incE | Z-value is not used in AF3
            filter_f3: float = 0.00005,  # --F3; 0.02 is used for queries shorter than 50 nucleotides
            get_tblout: bool = False,
            incdom_e: Optional[float] = None,
            dom_e: Optional[float] = None,
            num_streamed_chunks: Optional[int] = None,
            streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Nhmmer wrapper.

        Args:
          binary_path: The path to the nhmmer executable.
          database_path: The path to the sequence database (FASTA format).
          n_cpu: The number of CPUs to give nhmmer (--cpu).
          e_value: Reporting AND inclusion E-value (-E and --incE).
          filter_f3: Forward pre-filter threshold (--F3). Note this is
            overwritten per query based on query length in _query_chunk.
          get_tblout: Whether to save the --tblout table as a string.
          incdom_e: Domain e-value criteria for inclusion of domains in
            MSA/next round (--incdomE).
          dom_e: Domain e-value criteria for inclusion in tblout (--domE).
          num_streamed_chunks: Number of database chunks to stream over; when
            None, database_path is used as a single local file.
          streaming_callback: Callback function run after each chunk
            iteration with the iteration number as argument.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # A missing database is only fatal in non-streaming mode; streamed
        # chunks are fetched remotely on demand.
        if (
                not os.path.exists(self.database_path)
                and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.e_value = e_value
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
            self,
            input_fasta_path: str,
            database_path: str,
            max_sequences: Optional[int] = None
    ) -> Mapping[str, Any]:
        """Runs nhmmer for one single-sequence query against one database file.

        Returns:
          Dict with keys 'sto' (Stockholm alignment text), 'tbl' (tblout text
          or ""), 'stderr' (raw bytes from the subprocess), 'e_value'.
        """

        with open(input_fasta_path, "r") as f:
            sequences, desc = parsers.parse_fasta(f.read())
        assert len(sequences) == 1, f"Parse Fasta File with only 1 Sequence, but found {len(sequences)}"
        # AF3-style heuristic: loosen the forward pre-filter for very short
        # queries. NOTE(review): this mutates shared instance state on every
        # query, so one Nhmmer instance is not safe for concurrent queries.
        if len(sequences[0]) < 50:
            self.filter_f3 = 0.02
        else:
            self.filter_f3 = 0.00005
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # F3 is the expected proportion of sequences passing the
            # (progressively more expensive) forward pre-filter; lowering it
            # speeds up the search at the expense of sensitivity.
            cmd_flags = [
                # Don't pollute stdout with nhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--incE",
                str(self.e_value),
                "--rna",
                "--watson",
                "--F3",  # Only the F3 pre-filter is tuned here
                str(self.filter_f3),
                "--cpu",
                str(self.n_cpu),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                    [self.binary_path]
                    + cmd_flags
                    + [input_fasta_path, database_path]
            )
            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                    f"Nhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Nhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # tblout holds per-target e-values; only read back when requested.
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Optionally truncate the Stockholm MSA to max_sequences rows.
            if (max_sequences is None):
                with open(sto_path) as f:
                    sto = f.read()
            else:
                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                e_value=self.e_value,
            )

            return raw_output

    def query(self,
              input_fasta_path: str,
              max_sequences: Optional[int] = None
              ) -> Sequence[Sequence[Mapping[str, Any]]]:
        # Convenience wrapper: single-fasta version of query_multiple.
        return self.query_multiple([input_fasta_path], max_sequences)

    def query_multiple(self,
                       input_fasta_paths: Sequence[str],
                       max_sequences: Optional[int] = None
                       ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database using Nhmmer.

        Non-streaming mode returns one result dict per fasta; streaming mode
        returns, per fasta, a list with one result dict per database chunk.
        """
        if self.num_streamed_chunks is None:
            single_chunk_results = []
            for input_fasta_path in input_fasta_paths:
                single_chunk_result = self._query_chunk(
                    input_fasta_path, self.database_path, max_sequences,
                )
                single_chunk_results.append(single_chunk_result)
            return single_chunk_results

        # Streaming mode: chunks live at "<database_path>.<idx>" remotely and
        # are staged to a local ramdisk one at a time.
        db_basename = os.path.basename(self.database_path)
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM (stale chunks on the ramdisk).
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while nhmmer is running on the i-th chunk.
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally (first chunk kicks off the pipeline).
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Block until the current chunk has landed, then query it for
                # every input fasta.
                future.result()
                for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
                    chunked_outputs[fasta_idx].append(
                        self._query_chunk(
                            input_fasta_path,
                            db_local_chunk(i),
                            max_sequences
                        )
                    )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works
                # even for databases with only 1 chunk
                if (i < self.num_streamed_chunks):
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_outputs
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
if __name__ == '__main__':
    # Ad-hoc smoke test with hard-coded cluster paths — not meant for general
    # use; requires a local nhmmer install and a ./test.fa query file.
    nhmmer = Nhmmer(binary_path="/usr/bin/nhmmer",
                    # database_path="/group1/share01/data/alphafold3/rnacentral/v21.0/rnacentral.fasta")
                    database_path="/group1/share01/data/alphafold3/rfam/v14.9/Rfam_af3_clustered_all_seqs.fasta")
    out = nhmmer.query(input_fasta_path="./test.fa")
    print(out[0]["sto"])
|
PhysDock/data/tools/parse_msas.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Optional, Sequence, Dict, OrderedDict, Any, Union
|
| 4 |
+
from scipy.sparse import coo_matrix
|
| 5 |
+
|
| 6 |
+
from .parsers import parse_fasta, parse_hhr, parse_stockholm, parse_a3m, parse_hmmsearch_a3m, \
|
| 7 |
+
parse_hmmsearch_sto, Msa, parse_stockholm_file, Msa
|
| 8 |
+
from . import msa_identifiers
|
| 9 |
+
|
| 10 |
+
# Loose feature-dict alias: values may be dense arrays, sparse COO matrices,
# None, or arbitrary metadata objects.
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_txt(fname):
    """Return the entire text content of the file at `fname`."""
    with open(fname, "r") as handle:
        return handle.read()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# 3-letter codes for the 20 standard amino acids plus UNK; list position is
# the integer token id used in MSA features.
amino_acids = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
               "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]

# HHblits alphabet -> canonical amino-acid letter
# (ambiguity codes: B->D, Z->E, U->C, J/O->X; '-' is the gap).
HHBLITS_AA_TO_AA = {
    "A": "A",
    "B": "D",
    "C": "C",
    "D": "D",
    "E": "E",
    "F": "F",
    "G": "G",
    "H": "H",
    "I": "I",
    "J": "X",
    "K": "K",
    "L": "L",
    "M": "M",
    "N": "N",
    "O": "X",
    "P": "P",
    "Q": "Q",
    "R": "R",
    "S": "S",
    "T": "T",
    "U": "C",
    "V": "V",
    "W": "W",
    "X": "X",
    "Y": "Y",
    "Z": "E",
    "-": "-",
}
# Same 21 codes as `amino_acids` (kept as a separate name for clarity of use).
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
# One-letter -> three-letter residue codes (X maps to UNK).
amino_acid_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
}

amino_acid_3to1 = {v: k for k, v in amino_acid_1to3.items()}

# One-letter amino acid -> token id (index into `amino_acids`);
# the gap character '-' maps to id 31.
AA_TO_ID = {
    amino_acid_3to1[ccd]: amino_acids.index(ccd) for ccd in standard_protein
}
AA_TO_ID["-"] = 31

# NOTE(review): "robon" is presumably a typo for "ribo"; the name is kept as
# it may be referenced elsewhere in the package.
robon_nucleic_acids = ["A", "G", "C", "U", "N", ]

# RNA letters occupy token ids 21..25 (offset past the 21 protein ids);
# the gap '-' shares id 31 with the protein gap.
RNA_TO_ID = {ch: robon_nucleic_acids.index(ch) + 21 for ch in robon_nucleic_acids}
RNA_TO_ID["-"] = 31


# DEBUG (kept from original): mapping '.' to the gap id was tried and disabled.
# RNA_TO_ID["."] = 31
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def make_msa_features(msas: Sequence[Msa], is_rna=False) -> FeatureDict:
    """Constructs a feature dict of MSA features.

    Encodes each unique sequence across all MSAs into integer token ids
    ('msa'), collects per-row deletion counts ('deletion_matrix') and species
    identifiers ('msa_species_identifiers'). Duplicate sequences (across
    MSAs) are kept only once, in first-seen order.

    Args:
      msas: Parsed MSA objects; all sequences are assumed to be aligned to
        the same query length.
      is_rna: When True, residues are encoded with the RNA alphabet (unknown
        letters fall back to 'N'); otherwise the HHblits protein alphabet.

    Raises:
      ValueError: If `msas` is empty or any MSA has no sequences.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    species_ids = []
    seen_sequences = set()

    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(f"MSA {msa_index} must contain at least one sequence.")
        for sequence_index, (sequence, msa_deletion_matrix) in enumerate(
                zip(msa.sequences, msa.deletion_matrix)):
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            if is_rna:
                # Unknown RNA letters map to the 'N' token.
                int_msa.append(
                    [RNA_TO_ID.get(res, RNA_TO_ID["N"]) for res in sequence]
                )
                # Original RNA-specific deletion handling, left disabled:
                # deletion_matrix.append([
                #     msa_deletion_matrix[id] for id, res in enumerate(sequence)
                # ])
            else:
                int_msa.append(
                    [AA_TO_ID[HHBLITS_AA_TO_AA[res]] for res in sequence]
                )
            # NOTE(review): the deletion row is appended for both protein and
            # RNA branches here — confirm this matches the intended RNA
            # handling given the disabled block above.
            deletion_matrix.append(msa_deletion_matrix)
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[sequence_index]
            )
            species_ids.append(identifiers.species_id.encode("utf-8"))
    features = {}
    # NOTE(review): int8 caps deletion counts at 127 — verify upstream values
    # cannot exceed that.
    features["deletion_matrix"] = np.array(deletion_matrix, dtype=np.int8)

    features["msa"] = np.array(int_msa, dtype=np.int8)
    features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
    return features
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_alignment_dir(
        alignment_dir,
):
    """Loads every known MSA hit file under `alignment_dir` into features.

    Protein hits populate 'msa'/'deletion_matrix'/'msa_species_identifiers';
    uniprot hits populate the '_all_seq' (pairing) variants; RNA hits
    populate the plain MSA keys instead. Protein and RNA MSAs may not both be
    present. Missing files are simply skipped.
    """

    def _sto(filename):
        # Stockholm hit files parsed from their full text.
        path = os.path.join(alignment_dir, filename)
        if os.path.exists(path):
            return parse_stockholm(load_txt(path))
        return None

    def _a3m(filename):
        path = os.path.join(alignment_dir, filename)
        if os.path.exists(path):
            return parse_a3m(load_txt(path))
        return None

    def _sto_file(filename):
        # RNA Stockholm files are parsed directly from the path.
        path = os.path.join(alignment_dir, filename)
        if os.path.exists(path):
            return parse_stockholm_file(path)
        return None

    # MSA Order: uniref90, bfd_uniclust30/bfd_uniref30, mgnify.
    uniref90_msa = _sto("uniref90_hits.sto")
    uniprot_msa = _sto("uniprot_hits.sto")
    reduced_bfd_msa = _sto("reduced_bfd_hits.sto")
    mgnify_msa = _sto("mgnify_hits.sto")
    bfd_uniref30_msa = _a3m("bfd_uniref30_hits.a3m")
    bfd_uniclust30_msa = _a3m("bfd_uniclust30_hits.a3m")
    rfam_msa = _sto_file("rfam_hits2.sto")
    rnacentral_msa = _sto_file("rnacentral_hits.sto")
    nt_msa = _sto_file("nt_hits.sto")

    protein_msas = [m for m in (uniref90_msa, bfd_uniclust30_msa,
                                bfd_uniref30_msa, reduced_bfd_msa, mgnify_msa)
                    if m is not None]
    uniprot_msas = [m for m in (uniprot_msa,) if m is not None]
    rna_msas = [m for m in (rfam_msa, rnacentral_msa, nt_msa) if m is not None]

    output = dict()
    if len(uniprot_msas) > 0:
        # Uniprot hits feed the paired ('_all_seq') feature set.
        uniprot_msa_features = make_msa_features(uniprot_msas)
        output["msa_all_seq"] = uniprot_msa_features.pop("msa")
        output["deletion_matrix_all_seq"] = uniprot_msa_features.pop("deletion_matrix")
        output["msa_species_identifiers_all_seq"] = uniprot_msa_features.pop("msa_species_identifiers")
    if len(protein_msas) > 0:
        msa_features = make_msa_features(protein_msas)
        output["msa"] = msa_features.pop("msa")
        output["deletion_matrix"] = msa_features.pop("deletion_matrix")
        output["msa_species_identifiers"] = msa_features.pop("msa_species_identifiers")

    # TODO: DEBUG parse rna sto and
    if len(rna_msas) > 0:
        assert len(protein_msas) == 0
        msa_features = make_msa_features(rna_msas, is_rna=True)
        output["msa"] = msa_features.pop("msa")
        output["deletion_matrix"] = msa_features.pop("deletion_matrix")
        output["msa_species_identifiers"] = msa_features.pop("msa_species_identifiers")

    return output
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def parse_protein_alignment_dir(alignment_dir):
    """Collect protein MSA features from the hit files found in ``alignment_dir``.

    Returns a dict with "msa", "deletion_matrix" and "msa_species_identifiers"
    when at least one database hit file exists, otherwise an empty dict.
    """
    # MSA Order: uniref90 bfd_uniclust30/bfd_uniref30 mgnify
    hit_specs = (
        ("uniref90_hits.sto", parse_stockholm),
        ("bfd_uniclust30_hits.a3m", parse_a3m),
        ("bfd_uniref_hits.a3m", parse_a3m),
        ("reduced_bfd_hits.sto", parse_stockholm),
        ("mgnify_hits.sto", parse_stockholm),
    )

    parsed_msas = []
    for hits_filename, parser in hit_specs:
        hits_path = os.path.join(alignment_dir, hits_filename)
        if os.path.exists(hits_path):
            parsed_msas.append(parser(load_txt(hits_path)))

    output = dict()
    if parsed_msas:
        features = make_msa_features(parsed_msas)
        output["msa"] = features.pop("msa")
        output["deletion_matrix"] = features.pop("deletion_matrix")
        output["msa_species_identifiers"] = features.pop("msa_species_identifiers")
    return output
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def parse_uniprot_alignment_dir(
        alignment_dir,
):
    """Parse UniProt hits (used for cross-chain MSA pairing) from ``alignment_dir``.

    Returns "*_all_seq"-keyed features, or an empty dict when no
    uniprot_hits.sto file is present.
    """
    output = dict()
    sto_path = os.path.join(alignment_dir, "uniprot_hits.sto")
    if not os.path.exists(sto_path):
        return output

    features = make_msa_features([parse_stockholm(load_txt(sto_path))])
    output["msa_all_seq"] = features.pop("msa")
    output["deletion_matrix_all_seq"] = features.pop("deletion_matrix")
    output["msa_species_identifiers_all_seq"] = features.pop("msa_species_identifiers")
    return output
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def parse_rna_from_input_fasta_path(input_fasta_path):
    """Build a single-row Msa from the query FASTA file.

    The deletion matrix is all zeros since the query is trivially aligned
    to itself. Assumes the FASTA holds one sequence — TODO confirm; a
    multi-record file would make Msa's length validation fail.
    """
    with open(input_fasta_path, "r") as fasta_file:
        seqs, descs = parse_fasta(fasta_file.read())
    return Msa(
        sequences=seqs,
        deletion_matrix=[[0] * len(seqs[0])],
        descriptions=descs,
    )
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def parse_rna_single_alignment(input_fasta_path):
    """Create RNA MSA features from the query sequence alone (no database hits)."""
    features = make_msa_features(
        [parse_rna_from_input_fasta_path(input_fasta_path)], is_rna=True
    )
    return {
        "msa": features.pop("msa"),
        "deletion_matrix": features.pop("deletion_matrix"),
    }
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def parse_rna_alignment_dir(
        alignment_dir,
        input_fasta_path,
):
    """Parse realigned RNA database hits and prepend the query sequence.

    Missing hit files are skipped and empty MSAs are filtered out; the
    query FASTA always contributes the first alignment row.
    """
    rna_msas = [parse_rna_from_input_fasta_path(input_fasta_path)]
    for hits_filename in ("rfam_hits_realigned.sto",
                          "rnacentral_hits_realigned.sto",
                          "nt_hits_realigned.sto"):
        hits_path = os.path.join(alignment_dir, hits_filename)
        if os.path.exists(hits_path):
            rna_msas.append(parse_stockholm_file(hits_path))

    rna_msas = [m for m in rna_msas if len(m) > 0]
    features = make_msa_features(rna_msas, is_rna=True)
    return {
        "msa": features.pop("msa"),
        "deletion_matrix": features.pop("deletion_matrix"),
    }
|
PhysDock/data/tools/parsers.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Functions for parsing various file formats."""
|
| 17 |
+
import collections
|
| 18 |
+
import dataclasses
|
| 19 |
+
import itertools
|
| 20 |
+
import re
|
| 21 |
+
import string
|
| 22 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
|
| 23 |
+
|
| 24 |
+
DeletionMatrix = Sequence[Sequence[int]]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclasses.dataclass(frozen=True)
class Msa:
    """Class representing a parsed MSA file"""
    sequences: Sequence[str]
    deletion_matrix: DeletionMatrix
    descriptions: Optional[Sequence[str]]

    def __post_init__(self):
        # Each alignment row needs a sequence, a deletion vector and a
        # description; a mismatch means the MSA is inconsistent.
        field_lengths = {
            len(self.sequences),
            len(self.deletion_matrix),
            len(self.descriptions),
        }
        if len(field_lengths) != 1:
            raise ValueError(
                "All fields for an MSA must have the same length"
            )

    def __len__(self):
        return len(self.sequences)

    def truncate(self, max_seqs: int):
        """Return a new Msa keeping only the first ``max_seqs`` rows."""
        head = slice(None, max_seqs)
        return Msa(
            sequences=self.sequences[head],
            deletion_matrix=self.deletion_matrix[head],
            descriptions=self.descriptions[head],
        )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit parsed from an .hhr results file."""

    index: int  # Hit number taken from the 'No <n>' paragraph header.
    name: str  # Hit name line with the leading '>' removed.
    aligned_cols: int  # 'Aligned_cols' value from the .hhr summary line.
    sum_probs: Optional[float]  # 'Sum_probs' value from the .hhr summary line.
    query: str  # Query sequence as aligned in the hit (may contain '-').
    hit_sequence: str  # Hit sequence as aligned in the hit (may contain '-').
    indices_query: List[int]  # Per-column 0-based query residue index, -1 at gaps.
    indices_hit: List[int]  # Per-column 0-based hit residue index, -1 at gaps.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists:
      * A list of sequences.
      * A list of sequence descriptions taken from the comment lines. In the
        same order as the sequences.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    for raw_line in fasta_string.splitlines():
        stripped = raw_line.strip()
        if stripped.startswith(">"):
            # New record: header text (without '>') is its description.
            descriptions.append(stripped[1:])
            sequences.append("")
        elif stripped.startswith("#") or not stripped:
            # Skip comments and blank lines.
            continue
        else:
            # Continuation of the current record's sequence.
            sequences[-1] += stripped
    return sequences, descriptions
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def parse_stockholm(stockholm_string: str) -> Msa:
    """Parses sequences and deletion matrix from stockholm format alignment.

    Args:
      stockholm_string: The string contents of a stockholm file. The first
        sequence in the file should be the query sequence.

    Returns:
      An Msa holding:
      * The sequences aligned to the query (may contain duplicates), with
        query-gap columns removed.
      * The deletion matrix: element [i][j] is the number of residues of
        sequence i inserted before kept column j.
      * The names of the targets matched, including the jackhmmer
        subsequence suffix.
    """
    name_to_sequence = collections.OrderedDict()
    for raw_line in stockholm_string.splitlines():
        entry = raw_line.strip()
        # Skip blank lines, markup ('#...') and the '//' terminator.
        if not entry or entry.startswith(("#", "//")):
            continue
        name, chunk = entry.split()
        # Multi-block alignments repeat names; concatenate their chunks.
        name_to_sequence[name] = name_to_sequence.get(name, "") + chunk

    alignments = list(name_to_sequence.values())
    query = alignments[0] if alignments else ""
    # Columns where the query has a residue are the ones we keep.
    keep_columns = [i for i, res in enumerate(query) if res != "-"]

    msa = []
    deletion_matrix = []
    for sequence in alignments:
        msa.append("".join(sequence[c] for c in keep_columns))

        # For each kept column, count this sequence's residues that were
        # aligned against query gaps immediately before it.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res == "-" and query_res == "-":
                continue
            if query_res == "-":
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    return Msa(
        sequences=msa,
        deletion_matrix=deletion_matrix,
        descriptions=list(name_to_sequence.keys()),
    )
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def parse_stockholm_file(stockholm_file: str) -> Msa:
    """Parses sequences and deletion matrix from a stockholm alignment file.

    Unlike ``parse_stockholm``, this reads the alignment from disk line by
    line and strips lowercase characters, '*' and '.' from each row before
    stacking. NOTE(review): the stripping suggests this targets
    hmmalign-style realigned output — confirm against the callers.

    Args:
      stockholm_file: Path to a stockholm-format alignment file. The first
        sequence in the file should be the query sequence.

    Returns:
      An Msa holding:
      * The sequences aligned to the query (may contain duplicates), with
        query-gap columns removed.
      * The deletion matrix: element [i][j] is the number of residues of
        sequence i inserted before kept column j.
      * The names of the targets matched.
    """
    name_to_sequence = collections.OrderedDict()
    with open(stockholm_file, "r") as f:
        for line in f:
            line = line.strip()

            # Skip blank lines, markup ('#...') and the '//' terminator.
            if not line or line.startswith(("#", "//")):
                continue
            name, sequence = line.split()
            # Drop insert-state (lowercase), '*' and '.' characters.
            sequence = "".join([c for c in sequence if not c.islower() and c not in ["*","."]])
            if name not in name_to_sequence:
                name_to_sequence[name] = ""
            # Multi-block alignments repeat names; concatenate their chunks.
            name_to_sequence[name] += sequence

    msa = []
    deletion_matrix = []

    query = ""
    keep_columns = []
    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # Gather the columns with gaps from the query
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != "-"]

        # Remove the columns with gaps in the query from all sequences.
        aligned_sequence = "".join([sequence[c] for c in keep_columns])

        msa.append(aligned_sequence)

        # Count the number of deletions w.r.t. query.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res != "-" or query_res != "-":
                if query_res == "-":
                    deletion_count += 1
                else:
                    deletion_vec.append(deletion_count)
                    deletion_count = 0
        deletion_matrix.append(deletion_vec)
    return Msa(
        sequences=msa,
        deletion_matrix=deletion_matrix,
        descriptions=list(name_to_sequence.keys())
    )
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def parse_a3m(a3m_string: str) -> Msa:
    """Parses sequences and deletion matrix from a3m format alignment.

    In a3m, lowercase letters mark insertions relative to the query.

    Args:
      a3m_string: The string contents of a a3m file. The first sequence in the
        file should be the query sequence.

    Returns:
      An Msa holding:
      * The sequences aligned to the query (may contain duplicates), with
        insertion characters removed.
      * The deletion matrix: element [i][j] is the number of lowercase
        insertion characters immediately preceding aligned position j of
        sequence i.
    """
    sequences, descriptions = parse_fasta(a3m_string)

    # For each aligned (uppercase/gap) position, record the length of the
    # run of lowercase insertion characters immediately before it.
    deletion_matrix = []
    for seq in sequences:
        deletion_vec = []
        run = 0
        for ch in seq:
            if ch.islower():
                run += 1
                continue
            deletion_vec.append(run)
            run = 0
        deletion_matrix.append(deletion_vec)

    # Strip all lowercase insertion characters to get the aligned rows.
    strip_lower = str.maketrans("", "", string.ascii_lowercase)
    return Msa(
        sequences=[seq.translate(strip_lower) for seq in sequences],
        deletion_matrix=deletion_matrix,
        descriptions=descriptions,
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _convert_sto_seq_to_a3m(
|
| 267 |
+
query_non_gaps: Sequence[bool], sto_seq: str
|
| 268 |
+
) -> Iterable[str]:
|
| 269 |
+
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
|
| 270 |
+
if is_query_res_non_gap:
|
| 271 |
+
yield sequence_res
|
| 272 |
+
elif sequence_res != "-":
|
| 273 |
+
yield sequence_res.lower()
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def convert_stockholm_to_a3m(
        stockholm_format: str,
        max_sequences: Optional[int] = None,
        remove_first_row_gaps: bool = True,
) -> str:
    """Converts MSA in Stockholm format to the A3M format.

    Args:
      stockholm_format: The string contents of a stockholm alignment. The
        first sequence is treated as the query.
      max_sequences: If set, keep at most this many distinct sequences.
      remove_first_row_gaps: If True, columns gapped in the query are
        turned into lowercase insertions (a3m convention).

    Returns:
      The alignment rendered as an a3m/FASTA string with a trailing newline.
    """
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    # First pass: stitch together each sequence's alignment chunks.
    for line in stockholm_format.splitlines():
        reached_max_sequences = (
            max_sequences and len(sequences) >= max_sequences
        )
        if line.strip() and not line.startswith(("#", "//")):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ""
            sequences[seqname] += aligned_seq

    # Second pass: collect '#=GS ... DE ...' descriptions for kept sequences.
    # NOTE(review): reached_max_sequences below carries over the FINAL value
    # from the first loop rather than being recomputed — confirm intended.
    for line in stockholm_format.splitlines():
        if line[:4] == "#=GS":
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ""
            if feature != "DE":
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            if len(descriptions) == len(sequences):
                break

    # Convert sto format to a3m line by line
    a3m_sequences = {}
    if (remove_first_row_gaps):
        # query_sequence is assumed to be the first sequence
        query_sequence = next(iter(sequences.values()))
        query_non_gaps = [res != "-" for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        # Dots are optional in a3m format and are commonly removed.
        out_sequence = sto_sequence.replace('.', '')
        if (remove_first_row_gaps):
            out_sequence = ''.join(
                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
            )
        a3m_sequences[seqname] = out_sequence

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
        for k in a3m_sequences
    )
    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _keep_line(line: str, seqnames: Set[str]) -> bool:
|
| 338 |
+
"""Function to decide which lines to keep."""
|
| 339 |
+
if not line.strip():
|
| 340 |
+
return True
|
| 341 |
+
if line.strip() == '//': # End tag
|
| 342 |
+
return True
|
| 343 |
+
if line.startswith('# STOCKHOLM'): # Start tag
|
| 344 |
+
return True
|
| 345 |
+
if line.startswith('#=GC RF'): # Reference Annotation Line
|
| 346 |
+
return True
|
| 347 |
+
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
|
| 348 |
+
_, seqname, _ = line.split(maxsplit=2)
|
| 349 |
+
return seqname in seqnames
|
| 350 |
+
elif line.startswith('#'): # Other markup - filter out
|
| 351 |
+
return False
|
| 352 |
+
else: # Alignment data - keep if sequence in list.
|
| 353 |
+
seqname = line.partition(' ')[0]
|
| 354 |
+
return seqname in seqnames
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
    """Reads + truncates a Stockholm file while preventing excessive RAM usage.

    First pass collects the names of the first ``max_sequences`` sequences;
    second pass keeps only the lines belonging to those sequences (plus
    required markup), so the whole file is never duplicated in memory.
    """
    kept_names = set()
    with open(stockholm_msa_path) as sto:
        for raw in sto:
            # Only alignment rows (not blank lines, markup or '//') name
            # sequences.
            if not raw.strip() or raw.startswith(('#', '//')):
                continue
            kept_names.add(raw.partition(' ')[0])
            if len(kept_names) >= max_sequences:
                break

        # Rewind and emit only the surviving lines.
        sto.seek(0)
        surviving = [raw for raw in sto if _keep_line(raw, kept_names)]

    return ''.join(surviving)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
    """Removes empty columns (dashes-only) from a Stockholm MSA.

    The alignment is processed chunk by chunk: every '#=GC RF' line marks
    the end of a chunk, at which point a per-column mask is computed over
    the alignment rows buffered since the previous chunk, and each row
    (plus the RF line itself) is rewritten with the empty columns removed.
    NOTE(review): an alignment chunk without a trailing '#=GC RF' line is
    never flushed from ``unprocessed_lines`` — confirm inputs always carry
    the RF annotation.
    """
    processed_lines = {}
    unprocessed_lines = {}
    for i, line in enumerate(stockholm_msa.splitlines()):
        if line.startswith('#=GC RF'):
            reference_annotation_i = i
            reference_annotation_line = line
            # Reached the end of this chunk of the alignment. Process chunk.
            _, _, first_alignment = line.rpartition(' ')
            mask = []
            for j in range(len(first_alignment)):
                # Column j is kept if any buffered row has a residue there.
                for _, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    if alignment[j] != '-':
                        mask.append(True)
                        break
                else:  # Every row contained a hyphen - empty column.
                    mask.append(False)
            # Add reference annotation for processing with mask.
            unprocessed_lines[reference_annotation_i] = reference_annotation_line

            if not any(mask):  # All columns were empty. Output empty lines for chunk.
                for line_index in unprocessed_lines:
                    processed_lines[line_index] = ''
            else:
                for line_index, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    masked_alignment = ''.join(itertools.compress(alignment, mask))
                    processed_lines[line_index] = f'{prefix} {masked_alignment}'

            # Clear raw_alignments.
            unprocessed_lines = {}
        elif line.strip() and not line.startswith(('#', '//')):
            # Alignment row: buffer it until the chunk's RF line arrives.
            unprocessed_lines[i] = line
        else:
            # Blank lines, other markup and '//' pass through unchanged.
            processed_lines[i] = line
    return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
    """Remove duplicate sequences (ignoring insertions wrt query)."""
    sequence_dict = collections.defaultdict(str)

    # Stitch together each sequence's alignment chunks, skipping markup,
    # blank lines and the '//' terminator.
    for raw in stockholm_msa.splitlines():
        if not raw.strip() or raw.startswith(('#', '//')):
            continue
        name, chunk = raw.strip().split()
        sequence_dict[name] += chunk

    # The first alignment row is the query; columns where the query is
    # gapped are insertions and are ignored when comparing sequences.
    query_align = next(iter(sequence_dict.values()))
    mask = [c != '-' for c in query_align]

    seen_sequences = set()
    seqnames = set()
    for name, alignment in sequence_dict.items():
        masked = ''.join(itertools.compress(alignment, mask))
        if masked not in seen_sequences:
            seen_sequences.add(masked)
            seqnames.add(name)

    surviving = [raw for raw in stockholm_msa.splitlines()
                 if _keep_line(raw, seqnames)]
    return '\n'.join(surviving) + '\n'
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def _get_hhr_line_regex_groups(
|
| 456 |
+
regex_pattern: str, line: str
|
| 457 |
+
) -> Sequence[Optional[str]]:
|
| 458 |
+
match = re.match(regex_pattern, line)
|
| 459 |
+
if match is None:
|
| 460 |
+
raise RuntimeError(f"Could not parse query line {line}")
|
| 461 |
+
return match.groups()
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _update_hhr_residue_indices_list(
|
| 465 |
+
sequence: str, start_index: int, indices_list: List[int]
|
| 466 |
+
):
|
| 467 |
+
"""Computes the relative indices for each residue with respect to the original sequence."""
|
| 468 |
+
counter = start_index
|
| 469 |
+
for symbol in sequence:
|
| 470 |
+
if symbol == "-":
|
| 471 |
+
indices_list.append(-1)
|
| 472 |
+
else:
|
| 473 |
+
indices_list.append(counter)
|
| 474 |
+
counter += 1
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses the detailed HMM HMM comparison section for a single Hit.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
      detailed_lines: A list of lines from a single comparison section between 2
        sequences (which each have their own HMM's)

    Returns:
      A TemplateHit with the information from that detailed comparison section

    Raises:
      RuntimeError: If a certain line cannot be processed
    """
    # Parse first 2 lines: 'No <n>' and the '>' name line.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # Parse the summary line.
    pattern = (
        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
        "]*Template_Neff=(.*)"
    )
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            "Could not parse section: %s. Expected this: \n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )
    # Only Aligned_cols and Sum_probs are carried forward.
    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
        float(x) for x in match.groups()
    ]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ""
    hit_sequence = ""
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line
        if (
            line.startswith("Q ")
            and not line.startswith("Q ss_dssp")
            and not line.startswith("Q ss_pred")
            and not line.startswith("Q Consensus")
        ):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
            # everything after that.
            #              start    sequence       end       total_sequence_length
            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Get the length of the parsed block using the start and finish indices,
            # and ensure it is the same as the actual block length.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            num_insertions = len([x for x in delta_query if x == "-"])
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith("T "):
            # Parse the hit sequence.
            if (
                not line.startswith("T ss_dssp")
                and not line.startswith("T ss_pred")
                and not line.startswith("T Consensus")
            ):
                # Thus the first 17 characters must be 'T <hit_name> ', and we can
                # parse everything after that.
                #              start    sequence       end     total_sequence_length
                patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                # The preceding 'Q' line of this block fixed the block length.
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(
                    delta_hit_sequence, start, indices_hit
                )

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file."""
    lines = hhr_string.splitlines()

    # An .hhr file begins with a summary table, followed by one "paragraph"
    # per hit, each introduced by a line of the form 'No <hit number>'.
    # Each paragraph is parsed independently.
    starts = [idx for idx, ln in enumerate(lines) if ln.startswith("No ")]
    if not starts:
        return []

    # Append a sentinel so the final paragraph has an explicit end bound.
    bounds = starts + [len(lines)]
    return [
        _parse_hhr_hit(lines[begin:end])
        for begin, end in zip(bounds[:-1], bounds[1:])
    ]
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parse target to e-value mapping parsed from Jackhmmer tblout string.

    Args:
        tblout: Contents of a Jackhmmer ``--tblout`` file.

    Returns:
        A dict mapping each target name to its full-sequence e-value. The
        query itself is always present with an e-value of 0.
    """
    e_values: Dict[str, float] = {"query": 0.0}
    # Skip comment lines AND blank lines: the original `line[0] != "#"`
    # raised IndexError on an empty line.
    lines = [
        line for line in tblout.splitlines()
        if line.strip() and not line.startswith("#")
    ]
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name and
    # (5) E-value (full sequence) (numbering from 1).
    for line in lines:
        fields = line.split()
        target_name = fields[0]
        e_value = fields[4]
        e_values[target_name] = float(e_value)
    return e_values
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def _get_indices(sequence: str, start: int) -> List[int]:
|
| 619 |
+
"""Returns indices for non-gap/insert residues starting at the given index."""
|
| 620 |
+
indices = []
|
| 621 |
+
counter = start
|
| 622 |
+
for symbol in sequence:
|
| 623 |
+
# Skip gaps but add a placeholder so that the alignment is preserved.
|
| 624 |
+
if symbol == '-':
|
| 625 |
+
indices.append(-1)
|
| 626 |
+
# Skip deleted residues, but increase the counter.
|
| 627 |
+
elif symbol.islower():
|
| 628 |
+
counter += 1
|
| 629 |
+
# Normal aligned residue. Increase the counter and append to indices.
|
| 630 |
+
else:
|
| 631 |
+
indices.append(counter)
|
| 632 |
+
counter += 1
|
| 633 |
+
return indices
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
@dataclasses.dataclass(frozen=True)
class HitMetadata:
    """Metadata for one hit parsed from an hmmsearch A3M description line."""
    pdb_id: str  # Lower-case PDB entry id, e.g. "4pqx".
    chain: str  # Chain identifier within the PDB entry.
    start: int  # One-based start residue index of the hit sub-sequence.
    end: int  # One-based end residue index of the hit sub-sequence.
    length: int  # Full chain length reported after "length:" in the header.
    text: str  # Free-form trailing description text (may be empty).
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def _parse_hmmsearch_description(description: str) -> HitMetadata:
    """Parses the hmmsearch A3M sequence description line."""
    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
    pattern = r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$'
    match = re.match(pattern, description.strip())

    if match is None:
        raise ValueError(f'Could not parse description: "{description}".')

    pdb_id, chain, start, end, length, text = match.groups()
    return HitMetadata(
        pdb_id=pdb_id,
        chain=chain,
        start=int(start),
        end=int(end),
        length=int(length),
        text=text,
    )
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def parse_hmmsearch_a3m(
        query_sequence: str,
        a3m_string: str,
        skip_first: bool = True
) -> Sequence[TemplateHit]:
    """Parses an a3m string produced by hmmsearch.

    Args:
      query_sequence: The query sequence.
      a3m_string: The a3m string produced by hmmsearch.
      skip_first: Whether to skip the first sequence in the a3m string.

    Returns:
      A sequence of `TemplateHit` results.
    """
    # Pair each MSA row with its description; the first row is the query.
    entries = list(zip(*parse_fasta(a3m_string)))
    if skip_first:
        entries = entries[1:]

    indices_query = _get_indices(query_sequence, start=0)

    hits = []
    # Note: the 1-based index counts every entry, including skipped
    # non-protein chains, so indices match positions in the raw a3m.
    for hit_idx, (seq, desc) in enumerate(entries, start=1):
        if 'mol:protein' not in desc:
            continue  # Skip non-protein chains.
        meta = _parse_hmmsearch_description(desc)
        # Aligned columns are only the match states (upper case, non-gap).
        n_aligned = sum(1 for ch in seq if ch.isupper() and ch != '-')
        hit_indices = _get_indices(seq, start=meta.start - 1)

        hits.append(
            TemplateHit(
                index=hit_idx,
                name=f'{meta.pdb_id}_{meta.chain}',
                aligned_cols=n_aligned,
                sum_probs=None,
                query=query_sequence,
                hit_sequence=seq.upper(),
                indices_query=indices_query,
                indices_hit=hit_indices,
            )
        )

    return hits
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def parse_hmmsearch_sto(
        output_string: str,
        input_sequence: str
) -> Sequence[TemplateHit]:
    """Gets parsed template hits from the raw string output by the tool."""
    # Convert Stockholm to A3M first, keeping first-row gaps so hit
    # coordinates remain aligned to the untrimmed query.
    converted_a3m = convert_stockholm_to_a3m(
        output_string,
        remove_first_row_gaps=False
    )
    return parse_hmmsearch_a3m(
        query_sequence=input_sequence,
        a3m_string=converted_a3m,
        skip_first=False
    )
|
PhysDock/data/tools/rdkit.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import rdkit
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
from rdkit.Chem import AllChem, rdmolops
|
| 5 |
+
from rdkit.Chem.rdchem import ChiralType, BondType
|
| 6 |
+
import numpy as np
|
| 7 |
+
import copy
|
| 8 |
+
|
| 9 |
+
from PhysDock.data.constants.periodic_table import PeriodicTable
|
| 10 |
+
from PhysDock.utils.io_utils import load_txt
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_ref_mol(string):
    """Build an RDKit molecule with an embedded 3D conformer.

    Args:
        string: Either a SMILES string, or a path to a ``.smi`` file
            containing a single SMILES string.

    Returns:
        An RDKit ``Mol`` with one embedded conformer and all explicit
        hydrogens removed, or ``None`` if the input could not be parsed.
    """
    # Parse once and reuse the result (the original called MolFromSmiles
    # twice: once for the check and again for the assignment).
    mol = Chem.MolFromSmiles(string)
    if mol is None and os.path.isfile(string) and string.split(".")[-1] == "smi":
        mol = Chem.MolFromSmiles(load_txt(string).strip())
    if mol is not None:
        # Generate 3D coordinates; a high attempt cap helps flexible ligands.
        AllChem.EmbedMolecule(mol, maxAttempts=100000)
        mol = Chem.RemoveAllHs(mol)
    return mol
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Map RDKit hybridization states to small integer feature codes; lookups use
# Hybridization.get(..., 6) so any other state maps to 6.
Hybridization = {
    Chem.rdchem.HybridizationType.S: 0,
    Chem.rdchem.HybridizationType.SP: 1,
    Chem.rdchem.HybridizationType.SP2: 2,
    Chem.rdchem.HybridizationType.SP3: 3,
    Chem.rdchem.HybridizationType.SP3D: 4,
    Chem.rdchem.HybridizationType.SP3D2: 5,
}

# Chirality codes: CW/CCW tetrahedral centres get distinct codes; unspecified
# and "other" both collapse to 2 (also the .get() default used below).
Chirality = {ChiralType.CHI_TETRAHEDRAL_CW: 0,
             ChiralType.CHI_TETRAHEDRAL_CCW: 1,
             ChiralType.CHI_UNSPECIFIED: 2,
             ChiralType.CHI_OTHER: 2}
# Add None
# Bond-type codes; lookups use Bonds.get(..., 4) so any other bond type
# (e.g. dative) maps to 4.
Bonds = {BondType.SINGLE: 0, BondType.DOUBLE: 1, BondType.TRIPLE: 2, BondType.AROMATIC: 3}

# SMARTS matching any chain of four connected atoms (generic dihedral).
dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')


# Feats From SMI
# Feats From MOL
# Feats From SDF
|
| 54 |
+
|
| 55 |
+
def get_features_from_ref_mol(
        ref_mol,
        remove_hs=True
):
    """Extract label and conformer features from an RDKit molecule.

    Args:
        ref_mol: An RDKit Mol that already carries a 3D conformer; those
            coordinates are used as the ground-truth positions ``x_gt``.
        remove_hs: If True, strip all explicit hydrogens before featurizing.

    Returns:
        A tuple ``(label_feature, conf_feature, ref_mol)``:
        ``label_feature`` holds ground-truth coordinates/masks plus token
        bookkeeping (the whole ligand is a single token, restype 20);
        ``conf_feature`` holds per-atom chemical descriptors and pairwise
        bond/topology matrices computed from a freshly embedded conformer;
        ``ref_mol`` is the (re-embedded) molecule.
    """
    if remove_hs:
        ref_mol = Chem.RemoveAllHs(ref_mol)
    # print(ref_mol)
    # if ref_mol.GetNumConformers()==0:
    #     AllChem.EmbedMolecule(ref_mol,useExpTorsionAnglePrefs=True, useBasicKnowledge=True,maxAttempts=100000)
    # Ground-truth coordinates come from the conformer the molecule arrived with.
    ref_conf = ref_mol.GetConformer()
    x_gt = []
    for atom_id, atom in enumerate(ref_mol.GetAtoms()):
        atom_pos = ref_conf.GetAtomPosition(atom_id)
        x_gt.append(np.array([atom_pos.x, atom_pos.y, atom_pos.z]))
    x_gt = np.stack(x_gt, axis=0).astype(np.float32)
    # Every atom is marked present/resolved.
    x_exists = np.ones_like(x_gt[:, 0]).astype(np.int64)
    a_mask = np.ones_like(x_gt[:, 0]).astype(np.int64)

    # Ref Mol
    # NOTE(review): re-embedding below overwrites the conformer read above,
    # so ref_pos comes from a newly generated, MMFF-optimized conformer
    # rather than the input coordinates — confirm this is intended.
    AllChem.EmbedMolecule(ref_mol,maxAttempts=100000)
    AllChem.MMFFOptimizeMolecule(ref_mol)
    num_atoms = ref_mol.GetNumAtoms()
    conf = ref_mol.GetConformer()
    ring = ref_mol.GetRingInfo()

    # Filtering Conditions
    # if ref_mol.GetNumAtoms() < 4:
    #     return None
    # if ref_mol.GetNumBonds() < 4:
    #     return None
    #
    # k = 0
    # for conf in [conf]:
    #     # skip mols with atoms with more than 4 neighbors for now
    #     n_neighbors = [len(a.GetNeighbors()) for a in ref_mol.GetAtoms()]
    #     if np.max(n_neighbors) > 4:
    #         continue
    #     try:
    #         conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(ref_mol))
    #     except Exception as e:
    #         continue
    #     k += 1
    # if k == 0:
    #     return None

    # Per-atom descriptors collected from the reference conformer.
    ref_pos = []
    ref_charge = []
    ref_element = []
    ref_is_aromatic = []
    ref_degree = []
    ref_hybridization = []
    ref_implicit_valence = []
    ref_chirality = []
    ref_in_ring_of_3 = []
    ref_in_ring_of_4 = []
    ref_in_ring_of_5 = []
    ref_in_ring_of_6 = []
    ref_in_ring_of_7 = []
    ref_in_ring_of_8 = []
    for atom_id, atom in enumerate(ref_mol.GetAtoms()):
        atom_pos = conf.GetAtomPosition(atom_id)
        ref_pos.append(np.array([atom_pos.x, atom_pos.y, atom_pos.z]))
        ref_charge.append(atom.GetFormalCharge())
        # Zero-based element id (atomic number - 1) indexes PeriodicTable below.
        ref_element.append(atom.GetAtomicNum() - 1)
        ref_is_aromatic.append(int(atom.GetIsAromatic()))
        # Degree and valence are clamped to 8 to bound the feature vocabulary.
        ref_degree.append(min(atom.GetDegree(), 8))
        ref_hybridization.append(Hybridization.get(atom.GetHybridization(), 6))
        ref_implicit_valence.append(min(atom.GetImplicitValence(), 8))
        ref_chirality.append(Chirality.get(atom.GetChiralTag(), 2))
        # Ring-membership flags for ring sizes 3 through 8.
        ref_in_ring_of_3.append(int(ring.IsAtomInRingOfSize(atom_id, 3)))
        ref_in_ring_of_4.append(int(ring.IsAtomInRingOfSize(atom_id, 4)))
        ref_in_ring_of_5.append(int(ring.IsAtomInRingOfSize(atom_id, 5)))
        ref_in_ring_of_6.append(int(ring.IsAtomInRingOfSize(atom_id, 6)))
        ref_in_ring_of_7.append(int(ring.IsAtomInRingOfSize(atom_id, 7)))
        ref_in_ring_of_8.append(int(ring.IsAtomInRingOfSize(atom_id, 8)))

    ref_pos = np.stack(ref_pos, axis=0).astype(np.float32)
    ref_charge = np.array(ref_charge).astype(np.float32)
    ref_element = np.array(ref_element).astype(np.int8)
    ref_is_aromatic = np.array(ref_is_aromatic).astype(np.int8)
    ref_degree = np.array(ref_degree).astype(np.int8)
    ref_hybridization = np.array(ref_hybridization).astype(np.int8)
    ref_implicit_valence = np.array(ref_implicit_valence).astype(np.int8)
    ref_chirality = np.array(ref_chirality).astype(np.int8)
    ref_in_ring_of_3 = np.array(ref_in_ring_of_3).astype(np.int8)
    ref_in_ring_of_4 = np.array(ref_in_ring_of_4).astype(np.int8)
    ref_in_ring_of_5 = np.array(ref_in_ring_of_5).astype(np.int8)
    ref_in_ring_of_6 = np.array(ref_in_ring_of_6).astype(np.int8)
    ref_in_ring_of_7 = np.array(ref_in_ring_of_7).astype(np.int8)
    ref_in_ring_of_8 = np.array(ref_in_ring_of_8).astype(np.int8)

    # Pairwise matrices over atoms: topological distance plus bond attributes.
    d_token = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    token_bonds = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_type = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_as_double = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_in_ring = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_is_aromatic = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_is_conjugated = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    for i in range(num_atoms - 1):
        for j in range(i + 1, num_atoms):
            # Graph (bond-count) distance, capped at 30; symmetric.
            dist = len(rdmolops.GetShortestPath(ref_mol, i, j)) - 1
            dist = min(30, dist)
            d_token[i, j] = dist
            d_token[j, i] = dist
    for bond_id, bond in enumerate(ref_mol.GetBonds()):
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        token_bonds[i, j] = 1
        token_bonds[j, i] = 1
        bond_type[i, j] = Bonds.get(bond.GetBondType(), 4)
        bond_type[j, i] = Bonds.get(bond.GetBondType(), 4)
        bond_as_double[i, j] = bond.GetBondTypeAsDouble()
        bond_as_double[j, i] = bond.GetBondTypeAsDouble()
        bond_in_ring[i, j] = bond.IsInRing()
        bond_in_ring[j, i] = bond.IsInRing()
        bond_is_conjugated[i, j] = bond.GetIsConjugated()
        bond_is_conjugated[j, i] = bond.GetIsConjugated()
        bond_is_aromatic[i, j] = bond.GetIsAromatic()
        bond_is_aromatic[j, i] = bond.GetIsAromatic()

    # Element symbols stand in for atom names; all atoms flagged as ligand
    # (non-polymer) atoms.
    ref_atom_name_chars = [PeriodicTable[e] for e in ref_element.tolist()]
    ref_mask_in_polymer = [1] * len(ref_pos)

    # ccds, restype, residue_index, atom_id_to_conformer_atom_id, a_mask, x_gt, x_exists
    num_atoms = len(x_gt)
    # The whole ligand is one token: restype 20 ("UNK"), single residue index.
    label_feature = {
        "x_gt": x_gt,
        "x_exists": x_exists,
        "a_mask": a_mask,
        "restype": np.array([20]).astype(np.int64),
        "residue_index": np.arange(1).astype(np.int64),
        "atom_id_to_conformer_atom_id": np.arange(num_atoms).astype(np.int64),
        "conformer_id_to_chunk_sizes": np.array([num_atoms]).astype(np.int64)
    }
    conf_feature = {
        "ref_pos": ref_pos,
        "ref_charge": ref_charge,
        "ref_element": ref_element,
        "ref_is_aromatic": ref_is_aromatic,
        "ref_degree": ref_degree,
        "ref_hybridization": ref_hybridization,
        "ref_implicit_valence": ref_implicit_valence,
        "ref_chirality": ref_chirality,
        "ref_in_ring_of_3": ref_in_ring_of_3,
        "ref_in_ring_of_4": ref_in_ring_of_4,
        "ref_in_ring_of_5": ref_in_ring_of_5,
        "ref_in_ring_of_6": ref_in_ring_of_6,
        "ref_in_ring_of_7": ref_in_ring_of_7,
        "ref_in_ring_of_8": ref_in_ring_of_8,
        "d_token": d_token,
        "token_bonds": token_bonds,
        "bond_type": bond_type,
        "bond_as_double": bond_as_double,
        "bond_in_ring": bond_in_ring,
        "bond_is_conjugated": bond_is_conjugated,
        "bond_is_aromatic": bond_is_aromatic,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_mask_in_polymer": ref_mask_in_polymer,
    }
    return label_feature, conf_feature, ref_mol
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def get_features_from_smi(smi, remove_hs=True):
    """Compute label and conformer features directly from a SMILES string.

    Convenience wrapper: builds the reference molecule with ``get_ref_mol``
    and then delegates to ``get_features_from_ref_mol``.
    """
    mol = get_ref_mol(smi)
    labels, conformer, mol = get_features_from_ref_mol(mol, remove_hs=remove_hs)
    return labels, conformer, mol
|
PhysDock/data/tools/residue_constants.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np

# One-letter -> three-letter amino acid codes; "X" maps to the unknown
# residue "UNK".
amino_acid_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
}

amino_acid_3to1 = {v: k for k, v in amino_acid_1to3.items()}

# Ligand Atom is represented as "UNK" in token
# standard_residue is also ccd
# NOTE(review): nucleic CCD codes below carry trailing spaces — confirm the
# padding width matches how CCD codes are compared elsewhere.
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
standard_rna = ["A ", "G ", "C ", "U ", "N ", ]
standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
standard_nucleics = standard_rna + standard_dna
standard_ccds_without_gap = standard_protein + standard_nucleics
GAP = ["GAP"]  # used in msa one-hot
standard_ccds = standard_protein + standard_nucleics + GAP

# CCD code -> integer id in the standard ordering above.
standard_ccd_to_order = {ccd: id for id, ccd in enumerate(standard_ccds)}

standard_purines = ["A ", "G ", "DA ", "DG "]
standard_pyrimidines = ["C ", "U ", "DC ", "DT "]

# Membership predicates over CCD codes; "unknown" covers UNK/N/DN/GAP and
# the generic ligand code UNL.
is_standard = lambda x: x in standard_ccds
is_unk = lambda x: x in ["UNK", "N ", "DN ", "GAP", "UNL"]
is_protein = lambda x: x in standard_protein and not is_unk(x)
is_rna = lambda x: x in standard_rna and not is_unk(x)
is_dna = lambda x: x in standard_dna and not is_unk(x)
is_nucleics = lambda x: x in standard_nucleics and not is_unk(x)
is_purines = lambda x: x in standard_purines
is_pyrimidines = lambda x: x in standard_pyrimidines

# Atom count per standard CCD, in standard_ccds order; None for the
# unknown/gap entries whose atom count is undefined.
standard_ccd_to_atoms_num = {s: n for s, n in zip(standard_ccds, [
    5, 11, 8, 8, 6, 9, 9, 4, 10, 8,
    8, 9, 8, 11, 7, 6, 7, 14, 12, 7, None,
    22, 23, 20, 20, None,
    21, 22, 19, 20, None,
    None,
])}

# Token centre atom: CA for amino acids, C1' for nucleotides.
standard_ccd_to_token_centre_atom_name = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

# Three atoms defining the per-token frame (backbone frame for proteins,
# sugar atoms for nucleic acids).
standard_ccd_to_frame_atom_name_0 = {
    **{residue: "N" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_1 = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C3'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_2 = {
    **{residue: "C" for residue in standard_protein},
    **{residue: "C4'" for residue in standard_nucleics},
}

# Pseudo-beta atom: CB for amino acids (CA for glycine, patched below),
# C4 for purines and C2 for pyrimidines.
standard_ccd_to_token_pseudo_beta_atom_name = {
    **{residue: "CB" for residue in standard_protein},
    **{residue: "C4" for residue in standard_purines},
    **{residue: "C2" for residue in standard_pyrimidines},
}
standard_ccd_to_token_pseudo_beta_atom_name.update({"GLY": "CA"})

# HHblits residue-id ordering (alphabetical one-letter codes) -> CCD codes.
HHBLITS_ID_TO_AA = {
    0: "ALA",
    1: "CYS",  # Also U.
    2: "ASP",  # Also B.
    3: "GLU",  # Also Z.
    4: "PHE",
    5: "GLY",
    6: "HIS",
    7: "ILE",
    8: "LYS",
    9: "LEU",
    10: "MET",
    11: "ASN",
    12: "PRO",
    13: "GLN",
    14: "ARG",
    15: "SER",
    16: "THR",
    17: "VAL",
    18: "TRP",
    19: "TYR",
    20: "UNK",  # Includes J and O.
    21: "GAP",
}

# Usage: Convert hhblits msa to af3 aatype
# msa = hhblits_id_to_standard_residue_id_np[hhblits_msa.astype(np.int64)]
hhblits_id_to_standard_residue_id_np = np.array(
    [standard_ccds.index(ccd) for id, ccd in HHBLITS_ID_TO_AA.items()]
)

# OpenFold-style one-letter restype alphabet ("-" is the gap).
of_restypes = [
    "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
    "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X", "-"
]

# One-letter view of standard_ccds: amino acids by 3to1, GAP as "-", and the
# literal string "None" for nucleic entries with no one-letter code.
af3_restypes = [
    amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "-" if ccd == "GAP" else "None"
    for ccd in standard_ccds
]

# NOTE(review): this comprehension indexes af3_restypes against itself, so it
# yields the identity index for protein/gap entries and -1 elsewhere —
# presumably a mapping between of_restypes and af3_restypes was intended;
# verify against callers.
af3_if_to_residue_id = np.array(
    [af3_restypes.index(restype) if restype in of_restypes else -1 for restype in af3_restypes])

########################################################
#     periodic table used to encode elements           #
########################################################
# Element symbols ordered by atomic number; index = atomic number - 1.
periodic_table = [
    "h", "he",
    "li", "be", "b", "c", "n", "o", "f", "ne",
    "na", "mg", "al", "si", "p", "s", "cl", "ar",
    "k", "ca", "sc", "ti", "v", "cr", "mn", "fe", "co", "ni", "cu", "zn", "ga", "ge", "as", "se", "br", "kr",
    "rb", "sr", "y", "zr", "nb", "mo", "tc", "ru", "rh", "pd", "ag", "cd", "in", "sn", "sb", "te", "i", "xe",
    "cs", "ba",
    "la", "ce", "pr", "nd", "pm", "sm", "eu", "gd", "tb", "dy", "ho", "er", "tm", "yb", "lu",
    "hf", "ta", "w", "re", "os", "ir", "pt", "au", "hg", "tl", "pb", "bi", "po", "at", "rn",
    "fr", "ra",
    "ac", "th", "pa", "u", "np", "pu", "am", "cm", "bk", "cf", "es", "fm", "md", "no", "lr",
    "rf", "db", "sg", "bh", "hs", "mt", "ds", "rg", "cn", "nh", "fl", "mc", "lv", "ts", "og"
]

# Element symbol -> zero-based element id (atomic number - 1).
get_element_id = {ele: ele_id for ele_id, ele in enumerate(periodic_table)}

##########################################################
|
| 154 |
+
standard_ccd_to_reference_features_table = {
|
| 155 |
+
# letters_3: [ref_pos,ref_charge, ref_mask, ref_elements, ref_atom_name_chars]
|
| 156 |
+
"ALA": [
|
| 157 |
+
[-0.966, 0.493, 1.500, 0., 1, "N", "N"],
|
| 158 |
+
[0.257, 0.418, 0.692, 0., 1, "C", "CA"],
|
| 159 |
+
[-0.094, 0.017, -0.716, 0., 1, "C", "C"],
|
| 160 |
+
[-1.056, -0.682, -0.923, 0., 1, "O", "O"],
|
| 161 |
+
[1.204, -0.620, 1.296, 0., 1, "C", "CB"],
|
| 162 |
+
[0.661, 0.439, -1.742, 0., 0, "O", "OXT"],
|
| 163 |
+
],
|
| 164 |
+
"ARG": [
|
| 165 |
+
[-0.469, 1.110, -0.993, 0., 1, "N", "N"],
|
| 166 |
+
[0.004, 2.294, -1.708, 0., 1, "C", "CA"],
|
| 167 |
+
[-0.907, 2.521, -2.901, 0., 1, "C", "C"],
|
| 168 |
+
[-1.827, 1.789, -3.242, 0., 1, "O", "O"],
|
| 169 |
+
[1.475, 2.150, -2.127, 0., 1, "C", "CB"],
|
| 170 |
+
[1.745, 1.017, -3.130, 0., 1, "C", "CG"],
|
| 171 |
+
[3.210, 0.954, -3.557, 0., 1, "C", "CD"],
|
| 172 |
+
[4.071, 0.726, -2.421, 0., 1, "N", "NE"],
|
| 173 |
+
[5.469, 0.624, -2.528, 0., 1, "C", "CZ"],
|
| 174 |
+
[6.259, 0.404, -1.405, 0., 1, "N", "NH1"],
|
| 175 |
+
[6.078, 0.744, -3.773, 0., 1, "N", "NH2"],
|
| 176 |
+
[-0.588, 3.659, -3.574, 0., 0, "O", "OXT"],
|
| 177 |
+
],
|
| 178 |
+
"ASN": [
|
| 179 |
+
[-0.293, 1.686, 0.094, 0., 1, "N", "N"],
|
| 180 |
+
[-0.448, 0.292, -0.340, 0., 1, "C", "CA"],
|
| 181 |
+
[-1.846, -0.179, -0.031, 0., 1, "C", "C"],
|
| 182 |
+
[-2.510, 0.402, 0.794, 0., 1, "O", "O"],
|
| 183 |
+
[0.562, -0.588, 0.401, 0., 1, "C", "CB"],
|
| 184 |
+
[1.960, -0.197, -0.002, 0., 1, "C", "CG"],
|
| 185 |
+
[2.132, 0.697, -0.804, 0., 1, "O", "OD1"],
|
| 186 |
+
[3.019, -0.841, 0.527, 0., 1, "N", "ND2"],
|
| 187 |
+
[-2.353, -1.243, -0.673, 0., 0, "O", "OXT"],
|
| 188 |
+
],
|
| 189 |
+
"ASP": [
|
| 190 |
+
[-0.317, 1.688, 0.066, 0., 1, "N", "N"],
|
| 191 |
+
[-0.470, 0.286, -0.344, 0., 1, "C", "CA"],
|
| 192 |
+
[-1.868, -0.180, -0.029, 0., 1, "C", "C"],
|
| 193 |
+
[-2.534, 0.415, 0.786, 0., 1, "O", "O"],
|
| 194 |
+
[0.539, -0.580, 0.413, 0., 1, "C", "CB"],
|
| 195 |
+
[1.938, -0.195, 0.004, 0., 1, "C", "CG"],
|
| 196 |
+
[2.109, 0.681, -0.810, 0., 1, "O", "OD1"],
|
| 197 |
+
[2.992, -0.826, 0.543, 0., 1, "O", "OD2"],
|
| 198 |
+
[-2.374, -1.256, -0.652, 0., 0, "O", "OXT"],
|
| 199 |
+
],
|
| 200 |
+
"CYS": [
|
| 201 |
+
[1.585, 0.483, -0.081, 0., 1, "N", "N"],
|
| 202 |
+
[0.141, 0.450, 0.186, 0., 1, "C", "CA"],
|
| 203 |
+
[-0.095, 0.006, 1.606, 0., 1, "C", "C"],
|
| 204 |
+
[0.685, -0.742, 2.143, 0., 1, "O", "O"],
|
| 205 |
+
[-0.533, -0.530, -0.774, 0., 1, "C", "CB"],
|
| 206 |
+
[-0.247, 0.004, -2.484, 0., 1, "S", "SG"],
|
| 207 |
+
[-1.174, 0.443, 2.275, 0., 0, "O", "OXT"],
|
| 208 |
+
],
|
| 209 |
+
"GLN": [
|
| 210 |
+
[1.858, -0.148, 1.125, 0., 1, "N", "N"],
|
| 211 |
+
[0.517, 0.451, 1.112, 0., 1, "C", "CA"],
|
| 212 |
+
[-0.236, 0.022, 2.344, 0., 1, "C", "C"],
|
| 213 |
+
[-0.005, -1.049, 2.851, 0., 1, "O", "O"],
|
| 214 |
+
[-0.236, -0.013, -0.135, 0., 1, "C", "CB"],
|
| 215 |
+
[0.529, 0.421, -1.385, 0., 1, "C", "CG"],
|
| 216 |
+
[-0.213, -0.036, -2.614, 0., 1, "C", "CD"],
|
| 217 |
+
[-1.252, -0.650, -2.500, 0., 1, "O", "OE1"],
|
| 218 |
+
[0.277, 0.236, -3.839, 0., 1, "N", "NE2"],
|
| 219 |
+
[-1.165, 0.831, 2.878, 0., 0, "O", "OXT"],
|
| 220 |
+
],
|
| 221 |
+
"GLU": [
|
| 222 |
+
[1.199, 1.867, -0.117, 0., 1, "N", "N"],
|
| 223 |
+
[1.138, 0.515, 0.453, 0., 1, "C", "CA"],
|
| 224 |
+
[2.364, -0.260, 0.041, 0., 1, "C", "C"],
|
| 225 |
+
[3.010, 0.096, -0.916, 0., 1, "O", "O"],
|
| 226 |
+
[-0.113, -0.200, -0.062, 0., 1, "C", "CB"],
|
| 227 |
+
[-1.360, 0.517, 0.461, 0., 1, "C", "CG"],
|
| 228 |
+
[-2.593, -0.187, -0.046, 0., 1, "C", "CD"],
|
| 229 |
+
[-2.485, -1.161, -0.753, 0., 1, "O", "OE1"],
|
| 230 |
+
[-3.811, 0.269, 0.287, 0., 1, "O", "OE2"],
|
| 231 |
+
[2.737, -1.345, 0.737, 0., 0, "O", "OXT"],
|
| 232 |
+
],
|
| 233 |
+
"GLY": [
|
| 234 |
+
[1.931, 0.090, -0.034, 0., 1, "N", "N"],
|
| 235 |
+
[0.761, -0.799, -0.008, 0., 1, "C", "CA"],
|
| 236 |
+
[-0.498, 0.029, -0.005, 0., 1, "C", "C"],
|
| 237 |
+
[-0.429, 1.235, -0.023, 0., 1, "O", "O"],
|
| 238 |
+
[-1.697, -0.574, 0.018, 0., 0, "O", "OXT"],
|
| 239 |
+
],
|
| 240 |
+
"HIS": [
|
| 241 |
+
[-0.040, -1.210, 0.053, 0., 1, "N", "N"],
|
| 242 |
+
[1.172, -1.709, 0.652, 0., 1, "C", "CA"],
|
| 243 |
+
[1.083, -3.207, 0.905, 0., 1, "C", "C"],
|
| 244 |
+
[0.040, -3.770, 1.222, 0., 1, "O", "O"],
|
| 245 |
+
[1.484, -0.975, 1.962, 0., 1, "C", "CB"],
|
| 246 |
+
[2.940, -1.060, 2.353, 0., 1, "C", "CG"],
|
| 247 |
+
[3.380, -2.075, 3.129, 0., 1, "N", "ND1"],
|
| 248 |
+
[3.960, -0.251, 2.046, 0., 1, "C", "CD2"],
|
| 249 |
+
[4.693, -1.908, 3.317, 0., 1, "C", "CE1"],
|
| 250 |
+
[5.058, -0.801, 2.662, 0., 1, "N", "NE2"],
|
| 251 |
+
[2.247, -3.882, 0.744, 0., 0, "O", "OXT"],
|
| 252 |
+
],
|
| 253 |
+
"ILE": [
|
| 254 |
+
[-1.944, 0.335, -0.343, 0., 1, "N", "N"],
|
| 255 |
+
[-0.487, 0.519, -0.369, 0., 1, "C", "CA"],
|
| 256 |
+
[0.066, -0.032, -1.657, 0., 1, "C", "C"],
|
| 257 |
+
[-0.484, -0.958, -2.203, 0., 1, "O", "O"],
|
| 258 |
+
[0.140, -0.219, 0.814, 0., 1, "C", "CB"],
|
| 259 |
+
[-0.421, 0.341, 2.122, 0., 1, "C", "CG1"],
|
| 260 |
+
[1.658, -0.027, 0.788, 0., 1, "C", "CG2"],
|
| 261 |
+
[0.206, -0.397, 3.305, 0., 1, "C", "CD1"],
|
| 262 |
+
[1.171, 0.504, -2.197, 0., 0, "O", "OXT"],
|
| 263 |
+
],
|
| 264 |
+
"LEU": [
|
| 265 |
+
[-1.661, 0.627, -0.406, 0., 1, "N", "N"],
|
| 266 |
+
[-0.205, 0.441, -0.467, 0., 1, "C", "CA"],
|
| 267 |
+
[0.180, -0.055, -1.836, 0., 1, "C", "C"],
|
| 268 |
+
[-0.591, -0.731, -2.474, 0., 1, "O", "O"],
|
| 269 |
+
[0.221, -0.583, 0.585, 0., 1, "C", "CB"],
|
| 270 |
+
[-0.170, -0.079, 1.976, 0., 1, "C", "CG"],
|
| 271 |
+
[0.256, -1.104, 3.029, 0., 1, "C", "CD1"],
|
| 272 |
+
[0.526, 1.254, 2.250, 0., 1, "C", "CD2"],
|
| 273 |
+
[1.382, 0.254, -2.348, 0., 0, "O", "OXT"],
|
| 274 |
+
],
|
| 275 |
+
"LYS": [
|
| 276 |
+
[1.422, 1.796, 0.198, 0., 1, "N", "N"],
|
| 277 |
+
[1.394, 0.355, 0.484, 0., 1, "C", "CA"],
|
| 278 |
+
[2.657, -0.284, -0.032, 0., 1, "C", "C"],
|
| 279 |
+
[3.316, 0.275, -0.876, 0., 1, "O", "O"],
|
| 280 |
+
[0.184, -0.278, -0.206, 0., 1, "C", "CB"],
|
| 281 |
+
[-1.102, 0.282, 0.407, 0., 1, "C", "CG"],
|
| 282 |
+
[-2.313, -0.351, -0.283, 0., 1, "C", "CD"],
|
| 283 |
+
[-3.598, 0.208, 0.329, 0., 1, "C", "CE"],
|
| 284 |
+
[-4.761, -0.400, -0.332, 0., 1, "N", "NZ"],
|
| 285 |
+
[3.050, -1.476, 0.446, 0., 0, "O", "OXT"],
|
| 286 |
+
],
|
| 287 |
+
"MET": [
|
| 288 |
+
[-1.816, 0.142, -1.166, 0., 1, "N", "N"],
|
| 289 |
+
[-0.392, 0.499, -1.214, 0., 1, "C", "CA"],
|
| 290 |
+
[0.206, 0.002, -2.504, 0., 1, "C", "C"],
|
| 291 |
+
[-0.236, -0.989, -3.033, 0., 1, "O", "O"],
|
| 292 |
+
[0.334, -0.145, -0.032, 0., 1, "C", "CB"],
|
| 293 |
+
[-0.273, 0.359, 1.277, 0., 1, "C", "CG"],
|
| 294 |
+
[0.589, -0.405, 2.678, 0., 1, "S", "SD"],
|
| 295 |
+
[-0.314, 0.353, 4.056, 0., 1, "C", "CE"],
|
| 296 |
+
[1.232, 0.661, -3.066, 0., 0, "O", "OXT"],
|
| 297 |
+
],
|
| 298 |
+
"PHE": [
|
| 299 |
+
[1.317, 0.962, 1.014, 0., 1, "N", "N"],
|
| 300 |
+
[-0.020, 0.426, 1.300, 0., 1, "C", "CA"],
|
| 301 |
+
[-0.109, 0.047, 2.756, 0., 1, "C", "C"],
|
| 302 |
+
[0.879, -0.317, 3.346, 0., 1, "O", "O"],
|
| 303 |
+
[-0.270, -0.809, 0.434, 0., 1, "C", "CB"],
|
| 304 |
+
[-0.181, -0.430, -1.020, 0., 1, "C", "CG"],
|
| 305 |
+
[1.031, -0.498, -1.680, 0., 1, "C", "CD1"],
|
| 306 |
+
[-1.314, -0.018, -1.698, 0., 1, "C", "CD2"],
|
| 307 |
+
[1.112, -0.150, -3.015, 0., 1, "C", "CE1"],
|
| 308 |
+
[-1.231, 0.333, -3.032, 0., 1, "C", "CE2"],
|
| 309 |
+
[-0.018, 0.265, -3.691, 0., 1, "C", "CZ"],
|
| 310 |
+
[-1.286, 0.113, 3.396, 0., 0, "O", "OXT"],
|
| 311 |
+
],
|
| 312 |
+
"PRO": [
|
| 313 |
+
[-0.816, 1.108, 0.254, 0., 1, "N", "N"],
|
| 314 |
+
[0.001, -0.107, 0.509, 0., 1, "C", "CA"],
|
| 315 |
+
[1.408, 0.091, 0.005, 0., 1, "C", "C"],
|
| 316 |
+
[1.650, 0.980, -0.777, 0., 1, "O", "O"],
|
| 317 |
+
[-0.703, -1.227, -0.286, 0., 1, "C", "CB"],
|
| 318 |
+
[-2.163, -0.753, -0.439, 0., 1, "C", "CG"],
|
| 319 |
+
[-2.218, 0.614, 0.276, 0., 1, "C", "CD"],
|
| 320 |
+
[2.391, -0.721, 0.424, 0., 0, "O", "OXT"],
|
| 321 |
+
],
|
| 322 |
+
"SER": [
|
| 323 |
+
[1.525, 0.493, -0.608, 0., 1, "N", "N"],
|
| 324 |
+
[0.100, 0.469, -0.252, 0., 1, "C", "CA"],
|
| 325 |
+
[-0.053, 0.004, 1.173, 0., 1, "C", "C"],
|
| 326 |
+
[0.751, -0.760, 1.649, 0., 1, "O", "O"],
|
| 327 |
+
[-0.642, -0.489, -1.184, 0., 1, "C", "CB"],
|
| 328 |
+
[-0.496, -0.049, -2.535, 0., 1, "O", "OG"],
|
| 329 |
+
[-1.084, 0.440, 1.913, 0., 0, "O", "OXT"],
|
| 330 |
+
],
|
| 331 |
+
"THR": [
|
| 332 |
+
[1.543, -0.702, 0.430, 0., 1, "N", "N"],
|
| 333 |
+
[0.122, -0.706, 0.056, 0., 1, "C", "CA"],
|
| 334 |
+
[-0.038, -0.090, -1.309, 0., 1, "C", "C"],
|
| 335 |
+
[0.732, 0.761, -1.683, 0., 1, "O", "O"],
|
| 336 |
+
[-0.675, 0.104, 1.079, 0., 1, "C", "CB"],
|
| 337 |
+
[-0.193, 1.448, 1.103, 0., 1, "O", "OG1"],
|
| 338 |
+
[-0.511, -0.521, 2.466, 0., 1, "C", "CG2"],
|
| 339 |
+
[-1.039, -0.488, -2.110, 0., 0, "O", "OXT"],
|
| 340 |
+
],
|
| 341 |
+
"TRP": [
|
| 342 |
+
[1.278, 1.121, 2.059, 0., 1, "N", "N"],
|
| 343 |
+
[-0.008, 0.417, 1.970, 0., 1, "C", "CA"],
|
| 344 |
+
[-0.490, 0.076, 3.357, 0., 1, "C", "C"],
|
| 345 |
+
[0.308, -0.130, 4.240, 0., 1, "O", "O"],
|
| 346 |
+
[0.168, -0.868, 1.161, 0., 1, "C", "CB"],
|
| 347 |
+
[0.650, -0.526, -0.225, 0., 1, "C", "CG"],
|
| 348 |
+
[1.928, -0.418, -0.622, 0., 1, "C", "CD1"],
|
| 349 |
+
[-0.186, -0.256, -1.396, 0., 1, "C", "CD2"],
|
| 350 |
+
[1.978, -0.095, -1.951, 0., 1, "N", "NE1"],
|
| 351 |
+
[0.701, 0.014, -2.454, 0., 1, "C", "CE2"],
|
| 352 |
+
[-1.564, -0.210, -1.615, 0., 1, "C", "CE3"],
|
| 353 |
+
[0.190, 0.314, -3.712, 0., 1, "C", "CZ2"],
|
| 354 |
+
[-2.044, 0.086, -2.859, 0., 1, "C", "CZ3"],
|
| 355 |
+
[-1.173, 0.348, -3.907, 0., 1, "C", "CH2"],
|
| 356 |
+
[-1.806, 0.001, 3.610, 0., 0, "O", "OXT"],
|
| 357 |
+
],
|
| 358 |
+
"TYR": [
|
| 359 |
+
[1.320, 0.952, 1.428, 0., 1, "N", "N"],
|
| 360 |
+
[-0.018, 0.429, 1.734, 0., 1, "C", "CA"],
|
| 361 |
+
[-0.103, 0.094, 3.201, 0., 1, "C", "C"],
|
| 362 |
+
[0.886, -0.254, 3.799, 0., 1, "O", "O"],
|
| 363 |
+
[-0.274, -0.831, 0.907, 0., 1, "C", "CB"],
|
| 364 |
+
[-0.189, -0.496, -0.559, 0., 1, "C", "CG"],
|
| 365 |
+
[1.022, -0.589, -1.219, 0., 1, "C", "CD1"],
|
| 366 |
+
[-1.324, -0.102, -1.244, 0., 1, "C", "CD2"],
|
| 367 |
+
[1.103, -0.282, -2.563, 0., 1, "C", "CE1"],
|
| 368 |
+
[-1.247, 0.210, -2.587, 0., 1, "C", "CE2"],
|
| 369 |
+
[-0.032, 0.118, -3.252, 0., 1, "C", "CZ"],
|
| 370 |
+
[0.044, 0.420, -4.574, 0., 1, "O", "OH"],
|
| 371 |
+
[-1.279, 0.184, 3.842, 0., 0, "O", "OXT"],
|
| 372 |
+
],
|
| 373 |
+
"VAL": [
|
| 374 |
+
[1.564, -0.642, 0.454, 0., 1, "N", "N"],
|
| 375 |
+
[0.145, -0.698, 0.079, 0., 1, "C", "CA"],
|
| 376 |
+
[-0.037, -0.093, -1.288, 0., 1, "C", "C"],
|
| 377 |
+
[0.703, 0.784, -1.664, 0., 1, "O", "O"],
|
| 378 |
+
[-0.682, 0.086, 1.098, 0., 1, "C", "CB"],
|
| 379 |
+
[-0.497, -0.528, 2.487, 0., 1, "C", "CG1"],
|
| 380 |
+
[-0.218, 1.543, 1.119, 0., 1, "C", "CG2"],
|
| 381 |
+
[-1.022, -0.529, -2.089, 0., 0, "O", "OXT"],
|
| 382 |
+
],
|
| 383 |
+
"A ": [
|
| 384 |
+
[2.135, -1.141, -5.313, 0., 0, "O", "OP3"],
|
| 385 |
+
[1.024, -0.137, -4.723, 0., 1, "P", "P"],
|
| 386 |
+
[1.633, 1.190, -4.488, 0., 1, "O", "OP1"],
|
| 387 |
+
[-0.183, 0.005, -5.778, 0., 1, "O", "OP2"],
|
| 388 |
+
[0.456, -0.720, -3.334, 0., 1, "O", "O5'"],
|
| 389 |
+
[-0.520, 0.209, -2.863, 0., 1, "C", "C5'"],
|
| 390 |
+
[-1.101, -0.287, -1.538, 0., 1, "C", "C4'"],
|
| 391 |
+
[-0.064, -0.383, -0.538, 0., 1, "O", "O4'"],
|
| 392 |
+
[-2.105, 0.739, -0.969, 0., 1, "C", "C3'"],
|
| 393 |
+
[-3.445, 0.360, -1.287, 0., 1, "O", "O3'"],
|
| 394 |
+
[-1.874, 0.684, 0.558, 0., 1, "C", "C2'"],
|
| 395 |
+
[-3.065, 0.271, 1.231, 0., 1, "O", "O2'"],
|
| 396 |
+
[-0.755, -0.367, 0.729, 0., 1, "C", "C1'"],
|
| 397 |
+
[0.158, 0.029, 1.803, 0., 1, "N", "N9"],
|
| 398 |
+
[1.265, 0.813, 1.672, 0., 1, "C", "C8"],
|
| 399 |
+
[1.843, 0.963, 2.828, 0., 1, "N", "N7"],
|
| 400 |
+
[1.143, 0.292, 3.773, 0., 1, "C", "C5"],
|
| 401 |
+
[1.290, 0.091, 5.156, 0., 1, "C", "C6"],
|
| 402 |
+
[2.344, 0.664, 5.846, 0., 1, "N", "N6"],
|
| 403 |
+
[0.391, -0.656, 5.787, 0., 1, "N", "N1"],
|
| 404 |
+
[-0.617, -1.206, 5.136, 0., 1, "C", "C2"],
|
| 405 |
+
[-0.792, -1.051, 3.841, 0., 1, "N", "N3"],
|
| 406 |
+
[0.056, -0.320, 3.126, 0., 1, "C", "C4"],
|
| 407 |
+
],
|
| 408 |
+
"G ": [
|
| 409 |
+
[-1.945, -1.360, 5.599, 0., 0, "O", "OP3"],
|
| 410 |
+
[-0.911, -0.277, 5.008, 0., 1, "P", "P"],
|
| 411 |
+
[-1.598, 1.022, 4.844, 0., 1, "O", "OP1"],
|
| 412 |
+
[0.325, -0.105, 6.025, 0., 1, "O", "OP2"],
|
| 413 |
+
[-0.365, -0.780, 3.580, 0., 1, "O", "O5'"],
|
| 414 |
+
[0.542, 0.217, 3.109, 0., 1, "C", "C5'"],
|
| 415 |
+
[1.100, -0.200, 1.748, 0., 1, "C", "C4'"],
|
| 416 |
+
[0.033, -0.318, 0.782, 0., 1, "O", "O4'"],
|
| 417 |
+
[2.025, 0.898, 1.182, 0., 1, "C", "C3'"],
|
| 418 |
+
[3.395, 0.582, 1.439, 0., 1, "O", "O3'"],
|
| 419 |
+
[1.741, 0.884, -0.338, 0., 1, "C", "C2'"],
|
| 420 |
+
[2.927, 0.560, -1.066, 0., 1, "O", "O2'"],
|
| 421 |
+
[0.675, -0.220, -0.507, 0., 1, "C", "C1'"],
|
| 422 |
+
[-0.297, 0.162, -1.534, 0., 1, "N", "N9"],
|
| 423 |
+
[-1.440, 0.880, -1.334, 0., 1, "C", "C8"],
|
| 424 |
+
[-2.066, 1.037, -2.464, 0., 1, "N", "N7"],
|
| 425 |
+
[-1.364, 0.431, -3.453, 0., 1, "C", "C5"],
|
| 426 |
+
[-1.556, 0.279, -4.846, 0., 1, "C", "C6"],
|
| 427 |
+
[-2.534, 0.755, -5.397, 0., 1, "O", "O6"],
|
| 428 |
+
[-0.626, -0.401, -5.551, 0., 1, "N", "N1"],
|
| 429 |
+
[0.459, -0.934, -4.923, 0., 1, "C", "C2"],
|
| 430 |
+
[1.384, -1.626, -5.664, 0., 1, "N", "N2"],
|
| 431 |
+
[0.649, -0.800, -3.630, 0., 1, "N", "N3"],
|
| 432 |
+
[-0.226, -0.134, -2.868, 0., 1, "C", "C4"],
|
| 433 |
+
],
|
| 434 |
+
"C ": [
|
| 435 |
+
[2.147, -1.021, -4.678, 0., 0, "O", "OP3"],
|
| 436 |
+
[1.049, -0.039, -4.028, 0., 1, "P", "P"],
|
| 437 |
+
[1.692, 1.237, -3.646, 0., 1, "O", "OP1"],
|
| 438 |
+
[-0.116, 0.246, -5.102, 0., 1, "O", "OP2"],
|
| 439 |
+
[0.415, -0.733, -2.721, 0., 1, "O", "O5'"],
|
| 440 |
+
[-0.546, 0.181, -2.193, 0., 1, "C", "C5'"],
|
| 441 |
+
[-1.189, -0.419, -0.942, 0., 1, "C", "C4'"],
|
| 442 |
+
[-0.190, -0.648, 0.076, 0., 1, "O", "O4'"],
|
| 443 |
+
[-2.178, 0.583, -0.307, 0., 1, "C", "C3'"],
|
| 444 |
+
[-3.518, 0.283, -0.703, 0., 1, "O", "O3'"],
|
| 445 |
+
[-2.001, 0.373, 1.215, 0., 1, "C", "C2'"],
|
| 446 |
+
[-3.228, -0.059, 1.806, 0., 1, "O", "O2'"],
|
| 447 |
+
[-0.924, -0.729, 1.317, 0., 1, "C", "C1'"],
|
| 448 |
+
[-0.036, -0.470, 2.453, 0., 1, "N", "N1"],
|
| 449 |
+
[0.652, 0.683, 2.514, 0., 1, "C", "C2"],
|
| 450 |
+
[0.529, 1.504, 1.620, 0., 1, "O", "O2"],
|
| 451 |
+
[1.467, 0.945, 3.535, 0., 1, "N", "N3"],
|
| 452 |
+
[1.620, 0.070, 4.520, 0., 1, "C", "C4"],
|
| 453 |
+
[2.464, 0.350, 5.569, 0., 1, "N", "N4"],
|
| 454 |
+
[0.916, -1.151, 4.483, 0., 1, "C", "C5"],
|
| 455 |
+
[0.087, -1.399, 3.442, 0., 1, "C", "C6"],
|
| 456 |
+
],
|
| 457 |
+
"U ": [
|
| 458 |
+
[-2.122, 1.033, -4.690, 0., 0, "O", "OP3"],
|
| 459 |
+
[-1.030, 0.047, -4.037, 0., 1, "P", "P"],
|
| 460 |
+
[-1.679, -1.228, -3.660, 0., 1, "O", "OP1"],
|
| 461 |
+
[0.138, -0.241, -5.107, 0., 1, "O", "OP2"],
|
| 462 |
+
[-0.399, 0.736, -2.726, 0., 1, "O", "O5'"],
|
| 463 |
+
[0.557, -0.182, -2.196, 0., 1, "C", "C5'"],
|
| 464 |
+
[1.197, 0.415, -0.942, 0., 1, "C", "C4'"],
|
| 465 |
+
[0.194, 0.645, 0.074, 0., 1, "O", "O4'"],
|
| 466 |
+
[2.181, -0.588, -0.301, 0., 1, "C", "C3'"],
|
| 467 |
+
[3.524, -0.288, -0.686, 0., 1, "O", "O3'"],
|
| 468 |
+
[1.995, -0.383, 1.218, 0., 1, "C", "C2'"],
|
| 469 |
+
[3.219, 0.046, 1.819, 0., 1, "O", "O2'"],
|
| 470 |
+
[0.922, 0.723, 1.319, 0., 1, "C", "C1'"],
|
| 471 |
+
[0.028, 0.464, 2.451, 0., 1, "N", "N1"],
|
| 472 |
+
[-0.690, -0.671, 2.486, 0., 1, "C", "C2"],
|
| 473 |
+
[-0.587, -1.474, 1.580, 0., 1, "O", "O2"],
|
| 474 |
+
[-1.515, -0.936, 3.517, 0., 1, "N", "N3"],
|
| 475 |
+
[-1.641, -0.055, 4.530, 0., 1, "C", "C4"],
|
| 476 |
+
[-2.391, -0.292, 5.460, 0., 1, "O", "O4"],
|
| 477 |
+
[-0.894, 1.146, 4.502, 0., 1, "C", "C5"],
|
| 478 |
+
[-0.070, 1.384, 3.459, 0., 1, "C", "C6"],
|
| 479 |
+
],
|
| 480 |
+
"DA ": [
|
| 481 |
+
[1.845, -1.282, -5.339, 0., 0, "O", "OP3"],
|
| 482 |
+
[0.934, -0.156, -4.636, 0., 1, "P", "P"],
|
| 483 |
+
[1.781, 0.996, -4.255, 0., 1, "O", "OP1"],
|
| 484 |
+
[-0.204, 0.331, -5.665, 0., 1, "O", "OP2"],
|
| 485 |
+
[0.241, -0.771, -3.320, 0., 1, "O", "O5'"],
|
| 486 |
+
[-0.549, 0.270, -2.744, 0., 1, "C", "C5'"],
|
| 487 |
+
[-1.239, -0.251, -1.482, 0., 1, "C", "C4'"],
|
| 488 |
+
[-0.267, -0.564, -0.458, 0., 1, "O", "O4'"],
|
| 489 |
+
[-2.105, 0.859, -0.835, 0., 1, "C", "C3'"],
|
| 490 |
+
[-3.409, 0.895, -1.418, 0., 1, "O", "O3'"],
|
| 491 |
+
[-2.173, 0.398, 0.640, 0., 1, "C", "C2'"],
|
| 492 |
+
[-0.965, -0.545, 0.797, 0., 1, "C", "C1'"],
|
| 493 |
+
[-0.078, -0.047, 1.852, 0., 1, "N", "N9"],
|
| 494 |
+
[0.962, 0.817, 1.689, 0., 1, "C", "C8"],
|
| 495 |
+
[1.535, 1.044, 2.835, 0., 1, "N", "N7"],
|
| 496 |
+
[0.897, 0.346, 3.805, 0., 1, "C", "C5"],
|
| 497 |
+
[1.069, 0.196, 5.191, 0., 1, "C", "C6"],
|
| 498 |
+
[2.079, 0.869, 5.856, 0., 1, "N", "N6"],
|
| 499 |
+
[0.236, -0.603, 5.850, 0., 1, "N", "N1"],
|
| 500 |
+
[-0.729, -1.249, 5.224, 0., 1, "C", "C2"],
|
| 501 |
+
[-0.925, -1.144, 3.927, 0., 1, "N", "N3"],
|
| 502 |
+
[-0.142, -0.368, 3.184, 0., 1, "C", "C4"],
|
| 503 |
+
],
|
| 504 |
+
"DG ": [
|
| 505 |
+
[-1.603, -1.547, 5.624, 0., 0, "O", "OP3"],
|
| 506 |
+
[-0.818, -0.321, 4.935, 0., 1, "P", "P"],
|
| 507 |
+
[-1.774, 0.766, 4.630, 0., 1, "O", "OP1"],
|
| 508 |
+
[0.312, 0.224, 5.941, 0., 1, "O", "OP2"],
|
| 509 |
+
[-0.126, -0.826, 3.572, 0., 1, "O", "O5'"],
|
| 510 |
+
[0.550, 0.300, 3.011, 0., 1, "C", "C5'"],
|
| 511 |
+
[1.233, -0.113, 1.706, 0., 1, "C", "C4'"],
|
| 512 |
+
[0.253, -0.471, 0.705, 0., 1, "O", "O4'"],
|
| 513 |
+
[1.976, 1.091, 1.073, 0., 1, "C", "C3'"],
|
| 514 |
+
[3.294, 1.218, 1.612, 0., 1, "O", "O3'"],
|
| 515 |
+
[2.026, 0.692, -0.421, 0., 1, "C", "C2'"],
|
| 516 |
+
[0.897, -0.345, -0.573, 0., 1, "C", "C1'"],
|
| 517 |
+
[-0.068, 0.111, -1.575, 0., 1, "N", "N9"],
|
| 518 |
+
[-1.172, 0.877, -1.341, 0., 1, "C", "C8"],
|
| 519 |
+
[-1.804, 1.094, -2.458, 0., 1, "N", "N7"],
|
| 520 |
+
[-1.145, 0.482, -3.472, 0., 1, "C", "C5"],
|
| 521 |
+
[-1.361, 0.377, -4.866, 0., 1, "C", "C6"],
|
| 522 |
+
[-2.321, 0.914, -5.391, 0., 1, "O", "O6"],
|
| 523 |
+
[-0.473, -0.327, -5.601, 0., 1, "N", "N1"],
|
| 524 |
+
[0.593, -0.928, -5.003, 0., 1, "C", "C2"],
|
| 525 |
+
[1.474, -1.643, -5.774, 0., 1, "N", "N2"],
|
| 526 |
+
[0.804, -0.839, -3.709, 0., 1, "N", "N3"],
|
| 527 |
+
[-0.027, -0.152, -2.917, 0., 1, "C", "C4"],
|
| 528 |
+
],
|
| 529 |
+
"DC ": [
|
| 530 |
+
[1.941, -1.055, -4.672, 0., 0, "O", "OP3"],
|
| 531 |
+
[0.987, -0.017, -3.894, 0., 1, "P", "P"],
|
| 532 |
+
[1.802, 1.099, -3.365, 0., 1, "O", "OP1"],
|
| 533 |
+
[-0.119, 0.560, -4.910, 0., 1, "O", "OP2"],
|
| 534 |
+
[0.255, -0.772, -2.674, 0., 1, "O", "O5'"],
|
| 535 |
+
[-0.571, 0.196, -2.027, 0., 1, "C", "C5'"],
|
| 536 |
+
[-1.300, -0.459, -0.852, 0., 1, "C", "C4'"],
|
| 537 |
+
[-0.363, -0.863, 0.171, 0., 1, "O", "O4'"],
|
| 538 |
+
[-2.206, 0.569, -0.129, 0., 1, "C", "C3'"],
|
| 539 |
+
[-3.488, 0.649, -0.756, 0., 1, "O", "O3'"],
|
| 540 |
+
[-2.322, -0.040, 1.288, 0., 1, "C", "C2'"],
|
| 541 |
+
[-1.106, -0.981, 1.395, 0., 1, "C", "C1'"],
|
| 542 |
+
[-0.267, -0.584, 2.528, 0., 1, "N", "N1"],
|
| 543 |
+
[0.270, 0.648, 2.563, 0., 1, "C", "C2"],
|
| 544 |
+
[0.052, 1.424, 1.647, 0., 1, "O", "O2"],
|
| 545 |
+
[1.037, 1.035, 3.581, 0., 1, "N", "N3"],
|
| 546 |
+
[1.291, 0.212, 4.589, 0., 1, "C", "C4"],
|
| 547 |
+
[2.085, 0.622, 5.635, 0., 1, "N", "N4"],
|
| 548 |
+
[0.746, -1.088, 4.580, 0., 1, "C", "C5"],
|
| 549 |
+
[-0.035, -1.465, 3.541, 0., 1, "C", "C6"],
|
| 550 |
+
],
|
| 551 |
+
"DT ": [
|
| 552 |
+
[-3.912, -2.311, 1.636, 0., 0, "O", "OP3"],
|
| 553 |
+
[-3.968, -1.665, 3.118, 0., 1, "P", "P"],
|
| 554 |
+
[-4.406, -2.599, 4.208, 0., 1, "O", "OP1"],
|
| 555 |
+
[-4.901, -0.360, 2.920, 0., 1, "O", "OP2"],
|
| 556 |
+
[-2.493, -1.028, 3.315, 0., 1, "O", "O5'"],
|
| 557 |
+
[-2.005, -0.136, 2.327, 0., 1, "C", "C5'"],
|
| 558 |
+
[-0.611, 0.328, 2.728, 0., 1, "C", "C4'"],
|
| 559 |
+
[0.247, -0.829, 2.764, 0., 1, "O", "O4'"],
|
| 560 |
+
[0.008, 1.286, 1.720, 0., 1, "C", "C3'"],
|
| 561 |
+
[0.965, 2.121, 2.368, 0., 1, "O", "O3'"],
|
| 562 |
+
[0.710, 0.360, 0.754, 0., 1, "C", "C2'"],
|
| 563 |
+
[1.157, -0.778, 1.657, 0., 1, "C", "C1'"],
|
| 564 |
+
[1.164, -2.047, 0.989, 0., 1, "N", "N1"],
|
| 565 |
+
[2.333, -2.544, 0.374, 0., 1, "C", "C2"],
|
| 566 |
+
[3.410, -1.945, 0.363, 0., 1, "O", "O2"],
|
| 567 |
+
[2.194, -3.793, -0.240, 0., 1, "N", "N3"],
|
| 568 |
+
[1.047, -4.570, -0.300, 0., 1, "C", "C4"],
|
| 569 |
+
[0.995, -5.663, -0.857, 0., 1, "O", "O4"],
|
| 570 |
+
[-0.143, -3.980, 0.369, 0., 1, "C", "C5"],
|
| 571 |
+
[-1.420, -4.757, 0.347, 0., 1, "C", "C7"],
|
| 572 |
+
[-0.013, -2.784, 0.958, 0., 1, "C", "C6"],
|
| 573 |
+
],
|
| 574 |
+
}
|
| 575 |
+
|
| 576 |
+
# Map each standard (non-UNK) CCD code to the ordered list of its reference
# atom names (the last column of each per-atom reference-feature row).
standard_ccd_to_ref_atom_name_chars = {
    ccd: [row[-1] for row in standard_ccd_to_reference_features_table[ccd]]
    for ccd in standard_ccds
    if not is_unk(ccd)
}
|
| 580 |
+
|
| 581 |
+
# Cached identity matrices, used elsewhere in this module as one-hot lookup
# tables (row i of np.eye(n) is the one-hot vector for class i).
eye_3, eye_5, eye_7, eye_9, eye_32, eye_64, eye_128 = (
    np.eye(n) for n in (3, 5, 7, 9, 32, 64, 128)
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def _get_ref_feat_from_ccd_data(ccd, ref_feat_table):
    """Build the per-atom reference feature matrix for one CCD code.

    Each row of ``ref_feat_table[ccd]`` looks like
    ``[x, y, z, charge, flag, element, atom_name]``; the output row is the
    concatenation of the 5 numeric values, a 128-way one-hot of the element
    id, and four 64-way one-hots for the (space-padded) atom-name characters.
    """
    rows = []
    for atom in ref_feat_table[ccd]:
        # First five entries are plain numeric features.
        numeric = np.array(atom[:5])
        # Element symbol is lowercased before the id lookup.
        element_onehot = eye_128[get_element_id[atom[5].lower()]]
        # Atom name is left-justified to exactly 4 characters; each character
        # is one-hot encoded over printable ASCII (offset by ord(' ') == 32).
        name_onehots = [eye_64[ord(ch) - 32] for ch in f"{atom[-1]:<4}"]
        rows.append(np.concatenate([numeric, element_onehot, *name_onehots], axis=-1))
    return np.stack(rows, axis=0)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
# Precomputed per-atom reference feature matrices for every standard
# (non-UNK) CCD code.
standard_ccd_to_ref_feat = {
    ccd: _get_ref_feat_from_ccd_data(ccd, standard_ccd_to_reference_features_table)
    for ccd in standard_ccds
    if not is_unk(ccd)
}
|
PhysDock/data/tools/templates.py
ADDED
|
@@ -0,0 +1,1357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Functions for getting templates and calculating template features."""
|
| 17 |
+
import abc
|
| 18 |
+
import dataclasses
|
| 19 |
+
import datetime
|
| 20 |
+
import functools
|
| 21 |
+
import glob
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
from . import parsers, mmcif_parsing
|
| 30 |
+
from . import kalign
|
| 31 |
+
from .utils import to_date
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclasses.dataclass
class residue_constants:
    """Minimal, self-contained subset of AlphaFold-style residue constants.

    Only the pieces needed by the template-featurization code in this module
    are defined here (restype orderings, the 37-slot atom name table, the
    HHblits amino-acid id map, and one-hot sequence encoding).
    """

    # The 20 standard amino acids, single-letter codes, canonical ordering.
    restypes = [
        "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
        "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
    ]
    restype_order = dict(zip(restypes, range(len(restypes))))
    restypes_with_x = restypes + ["X"]
    restype_order_with_x = dict(zip(restypes_with_x, range(len(restypes_with_x))))
    # Number of heavy-atom slots in the fixed "atom37" representation.
    atom_type_num = 37
    atom_types = [
        "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG",
        "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1",
        "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2",
        "OH", "CZ", "CZ2", "CZ3", "NZ", "OXT",
    ]
    atom_order = dict(zip(atom_types, range(len(atom_types))))

    # HHblits amino-acid letter -> class id (ambiguous/rare letters are folded
    # into their closest standard residue; 20 is "unknown", 21 is the gap).
    HHBLITS_AA_TO_ID = {
        "A": 0, "B": 2, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6, "I": 7, "J": 20, "K": 8, "L": 9, "M": 10,
        "N": 11, "O": 20, "P": 12, "Q": 13, "R": 14, "S": 15, "T": 16, "U": 1, "V": 17, "W": 18, "X": 20, "Y": 19,
        "Z": 3, "-": 21,
    }

    @staticmethod
    def sequence_to_onehot(
        sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
    ) -> np.ndarray:
        """One-hot encode ``sequence`` according to ``mapping``.

        Args:
            sequence: An amino acid sequence (single-letter codes).
            mapping: Maps each amino-acid letter to its column index; the
                indices must form the contiguous range ``0..num_classes-1``.
            map_unknown_to_x: If True, uppercase letters missing from the
                mapping fall back to the index of ``'X'`` (the mapping must
                then contain ``'X'``), and any character that is not an
                uppercase letter raises ValueError. If False, unmapped
                characters raise KeyError.

        Returns:
            An int32 array of shape ``(len(sequence), num_classes)``.

        Raises:
            ValueError: If the mapping's indices are not ``0..num_classes-1``
                without gaps, or (with ``map_unknown_to_x``) the sequence
                contains an invalid character.
        """
        num_classes = max(mapping.values()) + 1
        if sorted(set(mapping.values())) != list(range(num_classes)):
            raise ValueError(
                "The mapping must have values from 0 to num_unique_aas-1 "
                "without any gaps. Got: %s" % sorted(mapping.values())
            )

        encoded = np.zeros((len(sequence), num_classes), dtype=np.int32)
        for pos, letter in enumerate(sequence):
            if not map_unknown_to_x:
                column = mapping[letter]
            elif letter.isalpha() and letter.isupper():
                column = mapping.get(letter, mapping["X"])
            else:
                raise ValueError(f"Invalid character in the sequence: {letter}")
            encoded[pos, column] = 1

        return encoded
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Error(Exception):
    """Base class for exceptions raised during template processing."""


class NoChainsError(Error):
    """An error indicating that template mmCIF didn't have any chains."""


class SequenceNotInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain the sequence."""


class NoAtomDataInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain atom positions."""


class TemplateAtomMaskAllZerosError(Error):
    """An error indicating that template mmCIF had all atom positions masked."""


class QueryToTemplateAlignError(Error):
    """An error indicating that the query can't be aligned to the template."""


class CaDistanceError(Error):
    """An error indicating that a CA atom distance exceeds a threshold."""
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Prefilter exceptions: raised while screening candidate hits before feature
# extraction. NOTE(review): this hierarchy is rooted at Exception rather than
# the module's Error base — presumably so callers can distinguish "hit was
# filtered out" from real processing failures; confirm before changing.
class PrefilterError(Exception):
    """A base class for template prefilter exceptions."""


class DateError(PrefilterError):
    """An error indicating that the hit date was after the max allowed date."""


class AlignRatioError(PrefilterError):
    """An error indicating that the hit align ratio to the query was too small."""


class DuplicateError(PrefilterError):
    """An error indicating that the hit was an exact subsequence of the query."""


class LengthError(PrefilterError):
    """An error indicating that the hit was too short."""
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Expected dtype for each template feature array produced by this module.
TEMPLATE_FEATURES = dict(
    template_aatype=np.int64,
    template_all_atom_masks=np.float32,
    template_all_atom_positions=np.float32,
    template_domain_names=object,
    template_sequence=object,
    template_sum_probs=np.float32,
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def empty_template_feats(n_res):
    """Returns a template feature dict with zero templates.

    The arrays carry a leading template dimension of size 0 (for the query
    length ``n_res``) so they can be concatenated with real template
    features downstream.

    Args:
        n_res: Number of residues in the query sequence.

    Returns:
        A dict with the same keys as ``TEMPLATE_FEATURES`` holding empty
        arrays (and single empty byte strings for the object-typed fields).
    """
    return {
        "template_aatype": np.zeros(
            (0, n_res, len(residue_constants.restypes_with_x_and_gap)),
            np.float32
        ),
        "template_all_atom_masks": np.zeros(
            (0, n_res, residue_constants.atom_type_num), np.float32
        ),
        "template_all_atom_positions": np.zeros(
            (0, n_res, residue_constants.atom_type_num, 3), np.float32
        ),
        # BUG FIX: `np.object` was deprecated in NumPy 1.20 and removed in
        # 1.24 (raises AttributeError); the builtin `object` is the correct
        # dtype for variable-length byte strings.
        "template_domain_names": np.array([''.encode()], dtype=object),
        "template_sequence": np.array([''.encode()], dtype=object),
        "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
    }
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
|
| 177 |
+
"""Returns PDB id and chain id for an HHSearch Hit."""
|
| 178 |
+
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
|
| 179 |
+
id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
|
| 180 |
+
if not id_match:
|
| 181 |
+
raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
|
| 182 |
+
pdb_id, chain_id = id_match.group(0).split("_")
|
| 183 |
+
return pdb_id.lower(), chain_id
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _is_after_cutoff(
|
| 187 |
+
pdb_id: str,
|
| 188 |
+
release_dates: Mapping[str, datetime.datetime],
|
| 189 |
+
release_date_cutoff: Optional[datetime.datetime],
|
| 190 |
+
) -> bool:
|
| 191 |
+
"""Checks if the template date is after the release date cutoff.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
pdb_id: 4 letter pdb code.
|
| 195 |
+
release_dates: Dictionary mapping PDB ids to their structure release dates.
|
| 196 |
+
release_date_cutoff: Max release date that is valid for this query.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
True if the template release date is after the cutoff, False otherwise.
|
| 200 |
+
"""
|
| 201 |
+
pdb_id_upper = pdb_id.upper()
|
| 202 |
+
if release_date_cutoff is None:
|
| 203 |
+
raise ValueError("The release_date_cutoff must not be None.")
|
| 204 |
+
if pdb_id_upper in release_dates:
|
| 205 |
+
return release_dates[pdb_id_upper] > release_date_cutoff
|
| 206 |
+
else:
|
| 207 |
+
# Since this is just a quick prefilter to reduce the number of mmCIF files
|
| 208 |
+
# we need to parse, we don't have to worry about returning True here.
|
| 209 |
+
logging.info(
|
| 210 |
+
"Template structure not in release dates dict: %s", pdb_id
|
| 211 |
+
)
|
| 212 |
+
return False
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
|
| 216 |
+
"""Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
|
| 217 |
+
obsolete_new = {}
|
| 218 |
+
obsolete_keys = obsolete_mapping.keys()
|
| 219 |
+
|
| 220 |
+
def _new_target(k):
|
| 221 |
+
v = obsolete_mapping[k]
|
| 222 |
+
if v in obsolete_keys:
|
| 223 |
+
return _new_target(v)
|
| 224 |
+
return v
|
| 225 |
+
|
| 226 |
+
for k in obsolete_keys:
|
| 227 |
+
obsolete_new[k] = _new_target(k)
|
| 228 |
+
|
| 229 |
+
return obsolete_new
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
    """Parses the PDB obsolete data file into an old-id -> new-id mapping.

    Entries that do not name a replacement are skipped; cross-references
    between obsolete entries are then flattened via
    `_replace_obsolete_references`.
    """
    mapping = {}
    with open(obsolete_file_path) as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            # Relevant rows look like: 'OBSLTE    31-JUL-94 116L     216L'.
            # Rows of 30 chars or fewer carry no replacement id.
            if not stripped.startswith("OBSLTE") or len(stripped) <= 30:
                continue
            mapping[stripped[20:24].lower()] = stripped[29:33].lower()
    return _replace_obsolete_references(mapping)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Scans a directory of mmCIF files and writes a release-date cache.

    Args:
        mmcif_dir: Directory containing ``.cif`` files.
        out_path: Path the JSON cache (file_id -> release date) is written to.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if f.endswith(".cif"):
            path = os.path.join(mmcif_dir, f)
            with open(path, "r") as fp:
                mmcif_string = fp.read()

            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if mmcif.mmcif_object is None:
                logging.info(f"Failed to parse {f}. Skipping...")
                continue

            mmcif = mmcif.mmcif_object
            release_date = mmcif.header["release_date"]

            dates[file_id] = release_date

    # BUG FIX: the cache file was opened with mode "r" and then written,
    # which always fails (unsupported operation / missing file); it must be
    # opened for writing.
    with open(out_path, "w") as fp:
        fp.write(json.dumps(dates))
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses a release dates JSON file into a PDB-id -> date mapping.

    The file maps each pdb id to a dict of metadata; only the
    "release_date" entry of each dict is kept, converted via `to_date`.
    Keys of the result are upper-cased pdb ids.
    """
    with open(path, "r") as fp:
        data = json.load(fp)

    release_dates = {}
    for pdb, metadata in data.items():
        for field, value in metadata.items():
            if field == "release_date":
                release_dates[pdb.upper()] = to_date(value)
    return release_dates
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _assess_hhsearch_hit(
    hit: "parsers.TemplateHit",
    hit_pdb_code: str,
    query_sequence: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.1,
) -> bool:
    """Checks a template hit's validity without parsing its mmCIF file.

    Args:
        hit: HhrHit for the template.
        hit_pdb_code: The 4 letter pdb code of the template hit.  May differ
            from the code in the hit itself if the original pdb became
            obsolete.
        query_sequence: Amino acid sequence of the query.
        release_dates: Maps pdb codes to structure release dates.
        release_date_cutoff: Max release date that is valid for this query.
        max_subsequence_ratio: Exclude exact matches above this overlap.
        min_align_ratio: Minimum overlap between the template and query.

    Returns:
        True if the hit passed the prefilter.

    Raises:
        DateError: If the hit date was after the max allowed date.
        AlignRatioError: If the hit align ratio to the query was too small.
        DuplicateError: If the hit was an exact subsequence of the query.
        LengthError: If the hit was too short.
    """
    align_ratio = hit.aligned_cols / len(query_sequence)

    template_sequence = hit.hit_sequence.replace("-", "")
    length_ratio = float(len(template_sequence)) / len(query_sequence)

    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        date = release_dates[hit_pdb_code.upper()]
        raise DateError(
            f"Date ({date}) > max template date "
            f"({release_date_cutoff})."
        )

    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            "Proportion of residues aligned to query too small. "
            f"Align ratio: {align_ratio}."
        )

    # Duplicate PDB entries can make a template an (almost) exact copy of
    # the query; reject those.
    if (template_sequence in query_sequence
            and length_ratio > max_subsequence_ratio):
        raise DuplicateError(
            "Template is an exact subsequence of query with large "
            f"coverage. Length ratio: {length_ratio}."
        )

    if len(template_sequence) < 10:
        raise LengthError(
            f"Template too short. Length: {len(template_sequence)}."
        )

    return True
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _find_template_in_pdb(
|
| 359 |
+
template_chain_id: str,
|
| 360 |
+
template_sequence: str,
|
| 361 |
+
mmcif_object: mmcif_parsing.MmcifObject,
|
| 362 |
+
) -> Tuple[str, str, int]:
|
| 363 |
+
"""Tries to find the template chain in the given pdb file.
|
| 364 |
+
|
| 365 |
+
This method tries the three following things in order:
|
| 366 |
+
1. Tries if there is an exact match in both the chain ID and the sequence.
|
| 367 |
+
If yes, the chain sequence is returned. Otherwise:
|
| 368 |
+
2. Tries if there is an exact match only in the sequence.
|
| 369 |
+
If yes, the chain sequence is returned. Otherwise:
|
| 370 |
+
3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
|
| 371 |
+
If yes, the chain sequence is returned.
|
| 372 |
+
If none of these succeed, a SequenceNotInTemplateError is thrown.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
template_chain_id: The template chain ID.
|
| 376 |
+
template_sequence: The template chain sequence.
|
| 377 |
+
mmcif_object: The PDB object to search for the template in.
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
A tuple with:
|
| 381 |
+
* The chain sequence that was found to match the template in the PDB object.
|
| 382 |
+
* The ID of the chain that is being returned.
|
| 383 |
+
* The offset where the template sequence starts in the chain sequence.
|
| 384 |
+
|
| 385 |
+
Raises:
|
| 386 |
+
SequenceNotInTemplateError: If no match is found after the steps described
|
| 387 |
+
above.
|
| 388 |
+
"""
|
| 389 |
+
# Try if there is an exact match in both the chain ID and the (sub)sequence.
|
| 390 |
+
pdb_id = mmcif_object.file_id
|
| 391 |
+
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
|
| 392 |
+
if chain_sequence and (template_sequence in chain_sequence):
|
| 393 |
+
logging.info(
|
| 394 |
+
"Found an exact template match %s_%s.", pdb_id, template_chain_id
|
| 395 |
+
)
|
| 396 |
+
mapping_offset = chain_sequence.find(template_sequence)
|
| 397 |
+
return chain_sequence, template_chain_id, mapping_offset
|
| 398 |
+
|
| 399 |
+
# Try if there is an exact match in the (sub)sequence only.
|
| 400 |
+
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
|
| 401 |
+
if chain_sequence and (template_sequence in chain_sequence):
|
| 402 |
+
logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
|
| 403 |
+
mapping_offset = chain_sequence.find(template_sequence)
|
| 404 |
+
return chain_sequence, chain_id, mapping_offset
|
| 405 |
+
|
| 406 |
+
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
|
| 407 |
+
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
|
| 408 |
+
regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
|
| 409 |
+
regex = re.compile("".join(regex))
|
| 410 |
+
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
|
| 411 |
+
match = re.search(regex, chain_sequence)
|
| 412 |
+
if match:
|
| 413 |
+
logging.info(
|
| 414 |
+
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
|
| 415 |
+
)
|
| 416 |
+
mapping_offset = match.start()
|
| 417 |
+
return chain_sequence, chain_id, mapping_offset
|
| 418 |
+
|
| 419 |
+
# No hits, raise an error.
|
| 420 |
+
raise SequenceNotInTemplateError(
|
| 421 |
+
"Could not find the template sequence in %s_%s. Template sequence: %s, "
|
| 422 |
+
"chain_to_seqres: %s"
|
| 423 |
+
% (
|
| 424 |
+
pdb_id,
|
| 425 |
+
template_chain_id,
|
| 426 |
+
template_sequence,
|
| 427 |
+
mmcif_object.chain_to_seqres,
|
| 428 |
+
)
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
    """Aligns template from the mmcif_object to the query.

    In case PDB70 contains a different version of the template sequence, we need
    to perform a realignment to the actual sequence that is in the mmCIF file.
    This method performs such realignment, but returns the new sequence and
    mapping only if the sequence in the mmCIF file is 90% identical to the old
    sequence.

    Note that the old_template_sequence comes from the hit, and contains only that
    part of the chain that matches with the query while the new_template_sequence
    is the full chain.

    Args:
        old_template_sequence: The template sequence that was returned by the PDB
            template search (typically done using HHSearch).
        template_chain_id: The template chain id was returned by the PDB template
            search (typically done using HHSearch). This is used to find the right
            chain in the mmcif_object chain_to_seqres mapping.
        mmcif_object: A mmcif_object which holds the actual template data.
        old_mapping: A mapping from the query sequence to the template sequence.
            This mapping will be used to compute the new mapping from the query
            sequence to the actual mmcif_object template sequence by aligning the
            old_template_sequence and the actual template sequence.
        kalign_binary_path: The path to a kalign executable.

    Returns:
        A tuple (new_template_sequence, new_query_to_template_mapping) where:
        * new_template_sequence is the actual template sequence that was found in
          the mmcif_object.
        * new_query_to_template_mapping is the new mapping from the query to the
          actual template found in the mmcif_object.

    Raises:
        QueryToTemplateAlignError:
        * If there was an error thrown by the alignment tool.
        * Or if the actual template sequence differs by more than 10% from the
          old_template_sequence.
    """
    aligner = kalign.Kalign(binary_path=kalign_binary_path)
    new_template_sequence = mmcif_object.chain_to_seqres.get(
        template_chain_id, ""
    )

    # Sometimes the template chain id is unknown. But if there is only a single
    # sequence within the mmcif_object, it is safe to assume it is that one.
    if not new_template_sequence:
        if len(mmcif_object.chain_to_seqres) == 1:
            logging.info(
                "Could not find %s in %s, but there is only 1 sequence, so "
                "using that one.",
                template_chain_id,
                mmcif_object.file_id,
            )
            new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
                0
            ]
        else:
            raise QueryToTemplateAlignError(
                f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
                "If there are no mmCIF parsing errors, it is possible it was not a "
                "protein chain."
            )

    try:
        # Pairwise-align the hit's version of the sequence against the full
        # chain from the mmCIF file using kalign.
        parsed_a3m = parsers.parse_a3m(
            aligner.align([old_template_sequence, new_template_sequence])
        )
        old_aligned_template, new_aligned_template = parsed_a3m.sequences
    except Exception as e:
        raise QueryToTemplateAlignError(
            "Could not align old template %s to template %s (%s_%s). Error: %s"
            % (
                old_template_sequence,
                new_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                str(e),
            )
        )

    logging.info(
        "Old aligned template: %s\nNew aligned template: %s",
        old_aligned_template,
        new_aligned_template,
    )

    # Walk both aligned sequences in lockstep, building a map from positions
    # in the old (hit) sequence to positions in the new (mmCIF) sequence and
    # counting identical residues for the 90% identity check below.
    old_to_new_template_mapping = {}
    old_template_index = -1
    new_template_index = -1
    num_same = 0
    for old_template_aa, new_template_aa in zip(
        old_aligned_template, new_aligned_template
    ):
        if old_template_aa != "-":
            old_template_index += 1
        if new_template_aa != "-":
            new_template_index += 1
        if old_template_aa != "-" and new_template_aa != "-":
            old_to_new_template_mapping[old_template_index] = new_template_index
            if old_template_aa == new_template_aa:
                num_same += 1

    # Require at least 90 % sequence identity wrt to the shorter of the sequences.
    if (
        float(num_same)
        / min(len(old_template_sequence), len(new_template_sequence))
        < 0.9
    ):
        raise QueryToTemplateAlignError(
            "Insufficient similarity of the sequence in the database: %s to the "
            "actual sequence in the mmCIF file %s_%s: %s. We require at least "
            "90 %% similarity wrt to the shorter of the sequences. This is not a "
            "problem unless you think this is a template that should be included."
            % (
                old_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                new_template_sequence,
            )
        )

    # Compose the query->old mapping with old->new; query positions whose old
    # template position has no counterpart in the new sequence map to -1.
    new_query_to_template_mapping = {}
    for query_index, old_template_index in old_mapping.items():
        new_query_to_template_mapping[
            query_index
        ] = old_to_new_template_mapping.get(old_template_index, -1)

    new_template_sequence = new_template_sequence.replace("-", "")

    return new_template_sequence, new_query_to_template_mapping
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _check_residue_distances(
    all_positions: np.ndarray,
    all_positions_mask: np.ndarray,
    max_ca_ca_distance: float,
):
    """Validates CA-CA distances between adjacent unmasked residues.

    Only pairs of directly adjacent residues whose CA atoms are both
    unmasked are checked; a masked residue breaks the pair so the gap
    around it is not validated.

    Raises:
        CaDistanceError: If any checked distance exceeds `max_ca_ca_distance`.
    """
    ca_index = residue_constants.atom_order["CA"]
    prev_unmasked = False
    prev_ca = None
    for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
        if not bool(mask[ca_index]):
            prev_unmasked = False
            continue
        this_ca = coords[ca_index]
        if prev_unmasked:
            distance = np.linalg.norm(this_ca - prev_ca)
            if distance > max_ca_ca_distance:
                raise CaDistanceError(
                    "The distance between residues %d and %d is %f > limit %f."
                    % (i, i + 1, distance, max_ca_ca_distance)
                )
        prev_ca = this_ca
        prev_unmasked = True
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _get_atom_positions(
    mmcif_object: "mmcif_parsing.MmcifObject",
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts per-residue atom positions and mask for one chain.

    The coordinates come from `mmcif_parsing.get_atom_coords`; adjacent
    unmasked CA atoms are then validated against `max_ca_ca_distance`
    (raising CaDistanceError via the helper on violation).
    """
    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object,
        chain_id=auth_chain_id,
        _zero_center_positions=_zero_center_positions,
    )
    _check_residue_distances(positions, mask, max_ca_ca_distance)
    return positions, mask
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str,
    _zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
    """Parses atom positions in the target structure and aligns with the query.

    Atoms for each residue in the template structure are indexed to coincide
    with their corresponding residue in the query sequence, according to the
    alignment mapping provided.

    Args:
        mmcif_object: mmcif_parsing.MmcifObject representing the template.
        pdb_id: PDB code for the template.
        mapping: Dictionary mapping indices in the query sequence to indices in
            the template sequence.
        template_sequence: String describing the amino acid sequence for the
            template protein.
        query_sequence: String describing the amino acid sequence for the query
            protein.
        template_chain_id: String ID describing which chain in the structure proto
            should be used.
        kalign_binary_path: The path to a kalign executable used for template
            realignment.
        _zero_center_positions: Forwarded to the coordinate extraction;
            presumably re-centers atom coordinates — confirm in
            mmcif_parsing.get_atom_coords.

    Returns:
        A tuple with:
        * A dictionary containing the extra features derived from the template
          protein structure.
        * A warning message if the hit was realigned to the actual mmCIF sequence.
          Otherwise None.

    Raises:
        NoChainsError: If the mmcif object doesn't contain any chains.
        SequenceNotInTemplateError: If the given chain id / sequence can't
            be found in the mmcif object.
        QueryToTemplateAlignError: If the actual template in the mmCIF file
            can't be aligned to the query.
        NoAtomDataInTemplateError: If the mmcif object doesn't contain
            atom positions.
        TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
            unmasked residues.
    """
    if mmcif_object is None or not mmcif_object.chain_to_seqres:
        raise NoChainsError(
            "No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
        )

    warning = None
    try:
        seqres, chain_id, mapping_offset = _find_template_in_pdb(
            template_chain_id=template_chain_id,
            template_sequence=template_sequence,
            mmcif_object=mmcif_object,
        )
    except SequenceNotInTemplateError:
        # If PDB70 contains a different version of the template, we use the sequence
        # from the mmcif_object.
        chain_id = template_chain_id
        warning = (
            f"The exact sequence {template_sequence} was not found in "
            f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
        )
        logging.warning(warning)
        # This throws an exception if it fails to realign the hit.
        seqres, mapping = _realign_pdb_template_to_query(
            old_template_sequence=template_sequence,
            template_chain_id=template_chain_id,
            mmcif_object=mmcif_object,
            old_mapping=mapping,
            kalign_binary_path=kalign_binary_path,
        )
        logging.info(
            "Sequence in %s_%s: %s successfully realigned to %s",
            pdb_id,
            chain_id,
            template_sequence,
            seqres,
        )
        # The template sequence changed.
        template_sequence = seqres
        # No mapping offset, the query is aligned to the actual sequence.
        mapping_offset = 0

    try:
        # Essentially set to infinity - we don't want to reject templates unless
        # they're really really bad.
        all_atom_positions, all_atom_mask = _get_atom_positions(
            mmcif_object,
            chain_id,
            max_ca_ca_distance=150.0,
            _zero_center_positions=_zero_center_positions,
        )
    except (CaDistanceError, KeyError) as ex:
        raise NoAtomDataInTemplateError(
            "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
        ) from ex

    # Split the per-chain arrays into one (1, num_atoms, ...) slice per
    # residue so individual residues can be scattered into the query frame.
    all_atom_positions = np.split(
        all_atom_positions, all_atom_positions.shape[0]
    )
    all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

    output_templates_sequence = []
    templates_all_atom_positions = []
    templates_all_atom_masks = []

    for _ in query_sequence:
        # Residues in the query_sequence that are not in the template_sequence:
        templates_all_atom_positions.append(
            np.zeros((residue_constants.atom_type_num, 3))
        )
        templates_all_atom_masks.append(
            np.zeros(residue_constants.atom_type_num)
        )
        output_templates_sequence.append("-")

    # Scatter each mapped template residue into its query position.
    for k, v in mapping.items():
        template_index = v + mapping_offset
        templates_all_atom_positions[k] = all_atom_positions[template_index][0]
        templates_all_atom_masks[k] = all_atom_masks[template_index][0]
        output_templates_sequence[k] = template_sequence[v]

    # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
    if np.sum(templates_all_atom_masks) < 5:
        raise TemplateAtomMaskAllZerosError(
            "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
            % (
                pdb_id,
                chain_id,
                min(mapping.values()) + mapping_offset,
                max(mapping.values()) + mapping_offset,
            )
        )

    output_templates_sequence = "".join(output_templates_sequence)

    templates_aatype = residue_constants.sequence_to_onehot(
        output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    return (
        {
            "template_all_atom_positions": np.array(
                templates_all_atom_positions
            ),
            "template_all_atom_masks": np.array(templates_all_atom_masks),
            "template_sequence": output_templates_sequence.encode(),
            "template_aatype": np.array(templates_aatype),
            "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
        },
        warning,
    )
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def _build_query_to_hit_index_mapping(
|
| 775 |
+
hit_query_sequence: str,
|
| 776 |
+
hit_sequence: str,
|
| 777 |
+
indices_hit: Sequence[int],
|
| 778 |
+
indices_query: Sequence[int],
|
| 779 |
+
original_query_sequence: str,
|
| 780 |
+
) -> Mapping[int, int]:
|
| 781 |
+
"""Gets mapping from indices in original query sequence to indices in the hit.
|
| 782 |
+
|
| 783 |
+
hit_query_sequence and hit_sequence are two aligned sequences containing gap
|
| 784 |
+
characters. hit_query_sequence contains only the part of the original query
|
| 785 |
+
sequence that matched the hit. When interpreting the indices from the .hhr, we
|
| 786 |
+
need to correct for this to recover a mapping from original query sequence to
|
| 787 |
+
the hit sequence.
|
| 788 |
+
|
| 789 |
+
Args:
|
| 790 |
+
hit_query_sequence: The portion of the query sequence that is in the .hhr
|
| 791 |
+
hit
|
| 792 |
+
hit_sequence: The portion of the hit sequence that is in the .hhr
|
| 793 |
+
indices_hit: The indices for each aminoacid relative to the hit sequence
|
| 794 |
+
indices_query: The indices for each aminoacid relative to the original query
|
| 795 |
+
sequence
|
| 796 |
+
original_query_sequence: String describing the original query sequence.
|
| 797 |
+
|
| 798 |
+
Returns:
|
| 799 |
+
Dictionary with indices in the original query sequence as keys and indices
|
| 800 |
+
in the hit sequence as values.
|
| 801 |
+
"""
|
| 802 |
+
# If the hit is empty (no aligned residues), return empty mapping
|
| 803 |
+
if not hit_query_sequence:
|
| 804 |
+
return {}
|
| 805 |
+
|
| 806 |
+
# Remove gaps and find the offset of hit.query relative to original query.
|
| 807 |
+
hhsearch_query_sequence = hit_query_sequence.replace("-", "")
|
| 808 |
+
hit_sequence = hit_sequence.replace("-", "")
|
| 809 |
+
hhsearch_query_offset = original_query_sequence.find(
|
| 810 |
+
hhsearch_query_sequence
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
|
| 814 |
+
min_idx = min(x for x in indices_hit if x > -1)
|
| 815 |
+
fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
|
| 816 |
+
|
| 817 |
+
min_idx = min(x for x in indices_query if x > -1)
|
| 818 |
+
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
|
| 819 |
+
|
| 820 |
+
# Zip the corrected indices, ignore case where both seqs have gap characters.
|
| 821 |
+
mapping = {}
|
| 822 |
+
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
|
| 823 |
+
if q_t != -1 and q_i != -1:
|
| 824 |
+
if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
|
| 825 |
+
original_query_sequence
|
| 826 |
+
):
|
| 827 |
+
continue
|
| 828 |
+
mapping[q_i + hhsearch_query_offset] = q_t
|
| 829 |
+
|
| 830 |
+
return mapping
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
    """Outcome of the template prefilter checks for a single hit."""

    # True when the hit passed all prefilter checks.
    valid: bool
    # Error message when the rejection is treated as an error, else None.
    error: Optional[str]
    # Non-fatal diagnostic message, if any.
    warning: Optional[str]
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
    """Result of extracting template features from one HHSearch hit."""

    # Extracted template features, or None when extraction failed.
    features: Optional[Mapping[str, Any]]
    # Error message for fatal failures, else None.
    error: Optional[str]
    # Non-fatal diagnostic message, else None.
    warning: Optional[str]
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def _prefilter_hit(
    query_sequence: str,
    hit: "parsers.TemplateHit",
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    strict_error_check: bool = False,
):
    """Runs the cheap prefilter checks for one HHSearch hit.

    Returns a PrefilterResult; with `strict_error_check`, date and duplicate
    rejections are reported via `error`, otherwise rejections are silent
    (valid=False, no error).
    """
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Remap the id of an obsolete entry to its replacement, if known.
    if hit_pdb_code not in release_dates:
        hit_pdb_code = obsolete_pdbs.get(hit_pdb_code, hit_pdb_code)

    # Pass hit_pdb_code since it might have changed due to the pdb being
    # obsolete.
    try:
        _assess_hhsearch_hit(
            hit=hit,
            hit_pdb_code=hit_pdb_code,
            query_sequence=query_sequence,
            release_dates=release_dates,
            release_date_cutoff=max_template_date,
        )
    except PrefilterError as e:
        msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}"
        logging.info(msg)
        if strict_error_check and isinstance(e, (DateError, DuplicateError)):
            # In strict mode we treat some prefilter cases as errors.
            return PrefilterResult(valid=False, error=msg, warning=None)
        return PrefilterResult(valid=False, error=None, warning=None)

    return PrefilterResult(valid=True, error=None, warning=None)
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
@functools.lru_cache(16, typed=False)
def _read_file(path):
    """Return the full text content of ``path``, caching recent reads.

    The small LRU cache avoids re-reading the same mmCIF file when several
    hits resolve to the same PDB entry.
    """
    with open(path, 'r') as handle:
        return handle.read()
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def _process_single_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    kalign_binary_path: str,
    strict_error_check: bool = False,
    _zero_center_positions: bool = True,
) -> SingleHitResult:
    """Tries to extract template features from a single HHSearch hit.

    Reads the hit's mmCIF from `mmcif_dir`, checks its release date against
    `max_template_date`, and extracts per-residue template features. Failures
    are routed into SingleHitResult.error (strict mode) or .warning
    (tolerant mode) rather than raised.
    """
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Remap obsolete PDB IDs to their replacement entries.
    if hit_pdb_code not in release_dates:
        if hit_pdb_code in obsolete_pdbs:
            hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    # Query-index -> hit-index alignment mapping derived from the HHSearch
    # alignment columns.
    mapping = _build_query_to_hit_index_mapping(
        hit.query,
        hit.hit_sequence,
        hit.indices_hit,
        hit.indices_query,
        query_sequence,
    )

    # The mapping is from the query to the actual hit sequence, so we need to
    # remove gaps (which regardless have a missing confidence score).
    template_sequence = hit.hit_sequence.replace("-", "")

    cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
    logging.info(
        "Reading PDB entry from %s. Query: %s, template: %s",
        cif_path,
        query_sequence,
        template_sequence,
    )

    # Fail if we can't find the mmCIF file.  NOTE(review): _read_file is
    # lru_cached, so repeated hits on the same entry avoid re-reading.
    cif_string = _read_file(cif_path)

    parsing_result = mmcif_parsing.parse(
        file_id=hit_pdb_code, mmcif_string=cif_string
    )

    # Date filter: only applied when the mmCIF header parsed successfully.
    if parsing_result.mmcif_object is not None:
        hit_release_date = datetime.datetime.strptime(
            parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
        )
        if hit_release_date > max_template_date:
            error = "Template %s date (%s) > max template date (%s)." % (
                hit_pdb_code,
                hit_release_date,
                max_template_date,
            )
            # Strict mode surfaces the violation; otherwise just log and skip.
            if strict_error_check:
                return SingleHitResult(features=None, error=error, warning=None)
            else:
                logging.info(error)
                return SingleHitResult(features=None, error=None, warning=None)

    try:
        features, realign_warning = _extract_template_features(
            mmcif_object=parsing_result.mmcif_object,
            pdb_id=hit_pdb_code,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=hit_chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=_zero_center_positions,
        )

        # sum_probs may legitimately be None (e.g. hmmsearch hits); store 0.
        if hit.sum_probs is None:
            features["template_sum_probs"] = [0]
        else:
            features["template_sum_probs"] = [hit.sum_probs]

        # It is possible there were some errors when parsing the other chains in the
        # mmCIF file, but the template features for the chain we want were still
        # computed. In such case the mmCIF parsing errors are not relevant.
        return SingleHitResult(
            features=features, error=None, warning=realign_warning
        )
    except (
        NoChainsError,
        NoAtomDataInTemplateError,
        TemplateAtomMaskAllZerosError,
    ) as e:
        # These 3 errors indicate missing mmCIF experimental data rather than a
        # problem with the template search, so turn them into warnings.
        warning = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        # Strict mode promotes these data-quality warnings to errors.
        if strict_error_check:
            return SingleHitResult(features=None, error=warning, warning=None)
        else:
            return SingleHitResult(features=None, error=None, warning=warning)
    except Error as e:
        # Any other pipeline Error is always reported as an error.
        error = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        return SingleHitResult(features=None, error=error, warning=None)
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):
    """Builds template features from a single user-supplied mmCIF structure.

    Unlike the search-based featurizers, the template is trusted fully
    (sum_probs fixed at 1.0) and the query is aligned to it with an
    identity residue mapping.

    Args:
        mmcif_path: Path to the custom template mmCIF file.
        query_sequence: The query amino-acid sequence.
        pdb_id: Identifier used when parsing the mmCIF.
        chain_id: Chain within the mmCIF to use as the template.
        kalign_binary_path: Path to a kalign executable for realignment.

    Returns:
        A TemplateSearchResult with a single-template feature stack.
    """
    # Don't shadow the mmcif_path argument with the open file handle.
    with open(mmcif_path, "r") as f:
        cif_string = f.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: query residue i aligns to template residue i.
    mapping = {x: x for x, _ in enumerate(query_sequence)}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    # A custom template is always fully trusted.
    features["template_sum_probs"] = [1.0]

    # Stack every feature along a new leading "template" axis (of size 1,
    # since there is exactly one custom template) and cast to the canonical
    # dtype.  Replaces the previous three-pass init/append/stack loops.
    template_features = {
        name: np.stack([features[name]], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    """Aggregated result of featurizing all hits for one query sequence."""

    # Stacked per-template feature arrays keyed by TEMPLATE_FEATURES names.
    features: Mapping[str, Any]
    # Error messages accumulated across all processed hits.
    errors: Sequence[str]
    # Warning messages accumulated across all processed hits.
    warnings: Sequence[str]
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
class TemplateHitFeaturizer(abc.ABC):
    """An abstract base class for turning template hits to features."""

    def __init__(
        self,
        mmcif_dir: str,
        max_template_date: str,
        max_hits: int,
        kalign_binary_path: str,
        release_dates_path: Optional[str] = None,
        obsolete_pdbs_path: Optional[str] = None,
        strict_error_check: bool = False,
        _shuffle_top_k_prefiltered: Optional[int] = None,
        _zero_center_positions: bool = True,
    ):
        """Initializes the Template Search.

        Args:
            mmcif_dir: Directory of mmCIF structures used to retrieve template
                data once a template ID has been found.
            max_template_date: Latest permitted template release date, in
                ISO8601 format (YYYY-MM-DD). Later templates are rejected.
            max_hits: Maximum number of templates to return.
            kalign_binary_path: Path to a kalign executable used for template
                realignment.
            release_dates_path: Optional file mapping PDB IDs to release dates,
                so mmCIF files need not be parsed just for that information.
            obsolete_pdbs_path: Optional file mapping obsolete PDB IDs to the
                IDs of their replacements.
            strict_error_check: If True, treat as errors: templates dated
                after max_template_date, templates with the query's PDB ID,
                duplicates of the query, and any feature computation errors.
        """
        self._mmcif_dir = mmcif_dir
        cif_pattern = os.path.join(self._mmcif_dir, "*.cif")
        if not glob.glob(cif_pattern):
            logging.error("Could not find CIFs in %s", self._mmcif_dir)
            raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")

        try:
            parsed_date = datetime.datetime.strptime(
                max_template_date, "%Y-%m-%d"
            )
        except ValueError:
            raise ValueError(
                "max_template_date must be set and have format YYYY-MM-DD."
            )
        self._max_template_date = parsed_date

        self._max_hits = max_hits
        self._kalign_binary_path = kalign_binary_path
        self._strict_error_check = strict_error_check
        self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
        self._zero_center_positions = _zero_center_positions

        # Precomputed metadata tables are optional; fall back to empty maps.
        self._release_dates = {}
        if release_dates_path:
            logging.info(
                "Using precomputed release dates %s.", release_dates_path
            )
            self._release_dates = _parse_release_dates(release_dates_path)

        self._obsolete_pdbs = {}
        if obsolete_pdbs_path:
            logging.info(
                "Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
            )
            self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)

    @abc.abstractmethod
    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        """ Computes the templates for a given query sequence """
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
    """Featurizes template hits produced by HHSearch."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for given query sequence (more details above)."""
        logging.info("Searching for template for: %s", query_sequence)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # Cheap metadata-only prefilter before any mmCIF parsing.
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                hit=hit,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        # Guard against hits whose sum_probs is None (treat as 0.), matching
        # HmmsearchHitFeaturizer; an unguarded key raises a TypeError when
        # Python compares None with a float.
        filtered = list(
            sorted(
                filtered,
                key=lambda x: x.sum_probs if x.sum_probs else 0.,
                reverse=True,
            )
        )

        idx = list(range(len(filtered)))
        if (self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            # We got all the templates we wanted, stop processing hits.
            if len(already_seen) >= self._max_hits:
                break
            try:
                hit = filtered[i]

                result = _process_single_hit(
                    query_sequence=query_sequence,
                    hit=hit,
                    mmcif_dir=self._mmcif_dir,
                    max_template_date=self._max_template_date,
                    release_dates=self._release_dates,
                    obsolete_pdbs=self._obsolete_pdbs,
                    strict_error_check=self._strict_error_check,
                    kalign_binary_path=self._kalign_binary_path,
                    _zero_center_positions=self._zero_center_positions,
                )

                if result.error:
                    errors.append(result.error)

                # There could be an error even if there are some results, e.g. thrown by
                # other unparsable chains in the same mmCIF file.
                if result.warning:
                    warnings.append(result.warning)

                if result.features is None:
                    logging.info(
                        "Skipped invalid hit %s, error: %s, warning: %s",
                        hit.name,
                        result.error,
                        result.warning,
                    )
                else:
                    # De-duplicate on the template sequence itself.
                    already_seen_key = result.features["template_sequence"]
                    if (already_seen_key in already_seen):
                        continue
                    already_seen.add(already_seen_key)
                    for k in template_features:
                        template_features[k].append(result.features[k])
            except Exception as e:
                # Best-effort: one bad hit should not abort the whole search.
                # Log properly instead of the previous bare print().
                logging.warning("Failed to featurize template hit: %s", e)
                continue

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = empty_template_feats(num_res)

        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )
|
| 1257 |
+
|
| 1258 |
+
|
| 1259 |
+
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
    """Featurizes template hits produced by hmmsearch."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        """Computes the templates for the given query sequence."""
        logging.info("Searching for template for: %s", query_sequence)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # DISCREPANCY: This filtering scheme that saves time
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                hit=hit,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        # hmmsearch hits may lack sum_probs; treat missing values as 0.
        filtered = list(
            sorted(
                filtered, key=lambda x: x.sum_probs if x.sum_probs else 0., reverse=True
            )
        )
        idx = list(range(len(filtered)))
        if (self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            if (len(already_seen) >= self._max_hits):
                break

            hit = filtered[i]

            result = _process_single_hit(
                query_sequence=query_sequence,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path,
                # Fix: forward the configured option; previously it was
                # silently ignored here (sibling HhsearchHitFeaturizer
                # already passes it through).
                _zero_center_positions=self._zero_center_positions,
            )

            if result.error:
                errors.append(result.error)

            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.debug(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name, result.error, result.warning,
                )
            else:
                # De-duplicate on the template sequence itself.
                already_seen_key = result.features["template_sequence"]
                if (already_seen_key in already_seen):
                    continue
                # Increment the hit counter, since we got features out of this hit.
                already_seen.add(already_seen_key)
                for k in template_features:
                    template_features[k].append(result.features[k])

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = empty_template_feats(num_res)

        return TemplateSearchResult(
            features=template_features,
            errors=errors,
            warnings=warnings,
        )
|
PhysDock/data/tools/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Common utilities for data pipeline tools."""
|
| 17 |
+
import contextlib
|
| 18 |
+
import datetime
|
| 19 |
+
import logging
|
| 20 |
+
import shutil
|
| 21 |
+
import tempfile
|
| 22 |
+
import time
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Yield a fresh temporary directory, removing it when the block exits.

    Args:
        base_dir: Optional parent directory for the temporary directory.
    """
    path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield path
    finally:
        # Best-effort cleanup; ignore races / partially-removed trees.
        shutil.rmtree(path, ignore_errors=True)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@contextlib.contextmanager
def timing(msg: str):
    """Log the wall-clock duration of the enclosed block."""
    logging.info("Started %s", msg)
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    logging.info("Finished %s in %.3f seconds", msg, elapsed)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def to_date(s: str):
    """Parse the leading ``YYYY-MM-DD`` of ``s`` into a datetime.

    Deliberately slices fixed positions rather than using strptime, so
    trailing text after the date (e.g. a time component) is tolerated.
    """
    year, month, day = int(s[:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
|
PhysDock/models/__init__.py
ADDED
|
File without changes
|
PhysDock/models/layers/__init__.py
ADDED
|
File without changes
|