lamthuy committed on
Commit
4947946
·
verified ·
1 Parent(s): 00a075b

Upload folder using huggingface_hub

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/SelfiesGen.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="31">
8
+ <item index="0" class="java.lang.String" itemvalue="deepspeed" />
9
+ <item index="1" class="java.lang.String" itemvalue="tqdm" />
10
+ <item index="2" class="java.lang.String" itemvalue="gensim" />
11
+ <item index="3" class="java.lang.String" itemvalue="transformers" />
12
+ <item index="4" class="java.lang.String" itemvalue="spacy" />
13
+ <item index="5" class="java.lang.String" itemvalue="scikit-learn" />
14
+ <item index="6" class="java.lang.String" itemvalue="seqeval" />
15
+ <item index="7" class="java.lang.String" itemvalue="torch" />
16
+ <item index="8" class="java.lang.String" itemvalue="datasets" />
17
+ <item index="9" class="java.lang.String" itemvalue="argparse" />
18
+ <item index="10" class="java.lang.String" itemvalue="biopython" />
19
+ <item index="11" class="java.lang.String" itemvalue="docarray" />
20
+ <item index="12" class="java.lang.String" itemvalue="deepsearch-toolkit" />
21
+ <item index="13" class="java.lang.String" itemvalue="bidict" />
22
+ <item index="14" class="java.lang.String" itemvalue="torch-scatter" />
23
+ <item index="15" class="java.lang.String" itemvalue="torch-sparse" />
24
+ <item index="16" class="java.lang.String" itemvalue="torchvision" />
25
+ <item index="17" class="java.lang.String" itemvalue="torch-geometric" />
26
+ <item index="18" class="java.lang.String" itemvalue="torchaudio" />
27
+ <item index="19" class="java.lang.String" itemvalue="pytest-runner" />
28
+ <item index="20" class="java.lang.String" itemvalue="pytest-cov" />
29
+ <item index="21" class="java.lang.String" itemvalue="pytorch-fast_transformers" />
30
+ <item index="22" class="java.lang.String" itemvalue="docarraygraph" />
31
+ <item index="23" class="java.lang.String" itemvalue="pandas" />
32
+ <item index="24" class="java.lang.String" itemvalue="rdt" />
33
+ <item index="25" class="java.lang.String" itemvalue="typer" />
34
+ <item index="26" class="java.lang.String" itemvalue="matplotlib" />
35
+ <item index="27" class="java.lang.String" itemvalue="accelerate" />
36
+ <item index="28" class="java.lang.String" itemvalue="numpy" />
37
+ <item index="29" class="java.lang.String" itemvalue="sdmetrics" />
38
+ <item index="30" class="java.lang.String" itemvalue="optuna" />
39
+ </list>
40
+ </value>
41
+ </option>
42
+ </inspection_tool>
43
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
44
+ <option name="ignoredErrors">
45
+ <list>
46
+ <option value="N801" />
47
+ </list>
48
+ </option>
49
+ </inspection_tool>
50
+ </profile>
51
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/SelfiesGen.iml" filepath="$PROJECT_DIR$/.idea/SelfiesGen.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="5f79c3e0-b184-484e-b0eb-1fc38b1d51bc" name="Changes" comment="" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="ProjectId" id="30Abq9BchQnHSDCvGYY3Tt9OU29" />
14
+ <component name="ProjectViewState">
15
+ <option name="hideEmptyMiddlePackages" value="true" />
16
+ <option name="showLibraryContents" value="true" />
17
+ </component>
18
+ <component name="PropertiesComponent"><![CDATA[{
19
+ "keyToString": {
20
+ "RunOnceActivity.OpenProjectViewOnStart": "true",
21
+ "RunOnceActivity.ShowReadmeOnStart": "true",
22
+ "settings.editor.selected.configurable": "reference.idesettings.debugger.python"
23
+ }
24
+ }]]></component>
25
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
26
+ <component name="TaskManager">
27
+ <task active="true" id="Default" summary="Default task">
28
+ <changelist id="5f79c3e0-b184-484e-b0eb-1fc38b1d51bc" name="Changes" comment="" />
29
+ <created>1753073512240</created>
30
+ <option name="number" value="Default" />
31
+ <option name="presentableId" value="Default" />
32
+ <updated>1753073512240</updated>
33
+ </task>
34
+ <servers />
35
+ </component>
36
+ </project>
.ipynb_checkpoints/Examples-checkpoint.ipynb ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "a1af2321-8860-4a3e-8406-a9ae587b97bf",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
11
+ "import selfies as sf\n",
12
+ "from rdkit import Chem\n",
13
+ "from typing import Optional\n",
14
+ "import numpy as np\n",
15
+ "import py3Dmol\n",
16
+ "from rdkit import Chem, DataStructs\n",
17
+ "from rdkit.Chem import AllChem\n",
18
+ "import torch\n",
19
+ "\n",
20
+ "def smiles_to_3d(smiles_list, width=400, height=300):\n",
21
+ " # Visualize the 3D structure using py3Dmol\n",
22
+ " view = py3Dmol.view(width=width, height=height)\n",
23
+ " for smiles in smiles_list:\n",
24
+ " # Generate the RDKit molecule object\n",
25
+ " mol = Chem.MolFromSmiles(smiles)\n",
26
+ " if mol is None:\n",
27
+ " raise ValueError(\"Invalid SMILES string\")\n",
28
+ "\n",
29
+ " # Add hydrogens to the molecule\n",
30
+ " mol = Chem.AddHs(mol)\n",
31
+ "\n",
32
+ " # Generate 3D coordinates\n",
33
+ " AllChem.EmbedMolecule(mol, randomSeed=42)\n",
34
+ " AllChem.UFFOptimizeMolecule(mol)\n",
35
+ "\n",
36
+ " # Generate the 3D structure in the form of a pdb string\n",
37
+ " pdb = Chem.MolToPDBBlock(mol)\n",
38
+ " view.addModel(pdb, 'pdb')\n",
39
+ " view.setStyle({'stick': {}})\n",
40
+ " view.zoomTo()\n",
41
+ " return view\n",
42
+ "\n",
43
+ " \n",
44
+ "# Load the checkpoint and the tokenizer\n",
45
+ "checkpoint_path = \"lamthuy/SelfiesGen\"\n",
46
+ "model = AutoModelForCausalLM.from_pretrained(checkpoint_path)\n",
47
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 9,
53
+ "id": "9b7066a4-6637-4d45-a0d9-3cc5e2ca0409",
54
+ "metadata": {},
55
+ "outputs": [
56
+ {
57
+ "data": {
58
+ "application/3dmoljs_load.v0": "<div id=\"3dmolviewer_1753263966683991\" style=\"position: relative; width: 400px; height: 300px;\">\n <p id=\"3dmolwarning_1753263966683991\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n </div>\n<script>\n\nvar loadScriptAsync = function(uri){\n return new Promise((resolve, reject) => {\n //this is to ignore the existence of requirejs amd\n var savedexports, savedmodule;\n if (typeof exports !== 'undefined') savedexports = exports;\n else exports = {}\n if (typeof module !== 'undefined') savedmodule = module;\n else module = {}\n\n var tag = document.createElement('script');\n tag.src = uri;\n tag.async = true;\n tag.onload = () => {\n exports = savedexports;\n module = savedmodule;\n resolve();\n };\n var firstScriptTag = document.getElementsByTagName('script')[0];\n firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n});\n};\n\nif(typeof $3Dmolpromise === 'undefined') {\n$3Dmolpromise = null;\n $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n}\n\nvar viewer_1753263966683991 = null;\nvar warn = document.getElementById(\"3dmolwarning_1753263966683991\");\nif(warn) {\n warn.parentNode.removeChild(warn);\n}\n$3Dmolpromise.then(function() {\nviewer_1753263966683991 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_1753263966683991\"),{backgroundColor:\"white\"});\nviewer_1753263966683991.zoomTo();\n\tviewer_1753263966683991.addModel(\"HETATM 1 C1 UNL 1 -2.975 -2.217 0.190 1.00 0.00 C \\nHETATM 2 C2 UNL 1 -2.096 -1.150 -0.368 1.00 0.00 C \\nHETATM 3 O1 UNL 1 -2.392 -0.615 -1.471 1.00 0.00 O \\nHETATM 4 O2 UNL 1 -0.898 -0.839 0.288 1.00 0.00 O \\nHETATM 5 C3 UNL 1 -0.095 0.284 0.025 1.00 0.00 C \\nHETATM 6 C4 UNL 1 -0.671 1.501 -0.391 1.00 0.00 C \\nHETATM 7 C5 UNL 1 0.129 2.617 -0.638 1.00 0.00 C \\nHETATM 8 C6 UNL 1 1.509 2.540 -0.463 1.00 0.00 C \\nHETATM 
9 C7 UNL 1 2.094 1.348 -0.032 1.00 0.00 C \\nHETATM 10 C8 UNL 1 1.307 0.209 0.223 1.00 0.00 C \\nHETATM 11 C9 UNL 1 1.972 -1.038 0.682 1.00 0.00 C \\nHETATM 12 O3 UNL 1 1.307 -2.087 0.898 1.00 0.00 O \\nHETATM 13 O4 UNL 1 3.351 -1.062 0.878 1.00 0.00 O \\nHETATM 14 H1 UNL 1 -4.038 -1.995 -0.040 1.00 0.00 H \\nHETATM 15 H2 UNL 1 -2.850 -2.271 1.291 1.00 0.00 H \\nHETATM 16 H3 UNL 1 -2.699 -3.195 -0.257 1.00 0.00 H \\nHETATM 17 H4 UNL 1 -1.742 1.599 -0.498 1.00 0.00 H \\nHETATM 18 H5 UNL 1 -0.324 3.547 -0.957 1.00 0.00 H \\nHETATM 19 H6 UNL 1 2.126 3.408 -0.653 1.00 0.00 H \\nHETATM 20 H7 UNL 1 3.168 1.322 0.102 1.00 0.00 H \\nHETATM 21 H8 UNL 1 3.818 -1.905 1.190 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5\\nCONECT 5 6 6 10\\nCONECT 6 7 17\\nCONECT 7 8 8 18\\nCONECT 8 9 19\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nCONECT 13 21\\nEND\\n\",\"pdb\");\n\tviewer_1753263966683991.setStyle({\"stick\": {}});\n\tviewer_1753263966683991.zoomTo();\nviewer_1753263966683991.render();\n});\n</script>",
59
+ "text/html": [
60
+ "<div id=\"3dmolviewer_1753263966683991\" style=\"position: relative; width: 400px; height: 300px;\">\n",
61
+ " <p id=\"3dmolwarning_1753263966683991\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n",
62
+ " </div>\n",
63
+ "<script>\n",
64
+ "\n",
65
+ "var loadScriptAsync = function(uri){\n",
66
+ " return new Promise((resolve, reject) => {\n",
67
+ " //this is to ignore the existence of requirejs amd\n",
68
+ " var savedexports, savedmodule;\n",
69
+ " if (typeof exports !== 'undefined') savedexports = exports;\n",
70
+ " else exports = {}\n",
71
+ " if (typeof module !== 'undefined') savedmodule = module;\n",
72
+ " else module = {}\n",
73
+ "\n",
74
+ " var tag = document.createElement('script');\n",
75
+ " tag.src = uri;\n",
76
+ " tag.async = true;\n",
77
+ " tag.onload = () => {\n",
78
+ " exports = savedexports;\n",
79
+ " module = savedmodule;\n",
80
+ " resolve();\n",
81
+ " };\n",
82
+ " var firstScriptTag = document.getElementsByTagName('script')[0];\n",
83
+ " firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n",
84
+ "});\n",
85
+ "};\n",
86
+ "\n",
87
+ "if(typeof $3Dmolpromise === 'undefined') {\n",
88
+ "$3Dmolpromise = null;\n",
89
+ " $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n",
90
+ "}\n",
91
+ "\n",
92
+ "var viewer_1753263966683991 = null;\n",
93
+ "var warn = document.getElementById(\"3dmolwarning_1753263966683991\");\n",
94
+ "if(warn) {\n",
95
+ " warn.parentNode.removeChild(warn);\n",
96
+ "}\n",
97
+ "$3Dmolpromise.then(function() {\n",
98
+ "viewer_1753263966683991 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_1753263966683991\"),{backgroundColor:\"white\"});\n",
99
+ "viewer_1753263966683991.zoomTo();\n",
100
+ "\tviewer_1753263966683991.addModel(\"HETATM 1 C1 UNL 1 -2.975 -2.217 0.190 1.00 0.00 C \\nHETATM 2 C2 UNL 1 -2.096 -1.150 -0.368 1.00 0.00 C \\nHETATM 3 O1 UNL 1 -2.392 -0.615 -1.471 1.00 0.00 O \\nHETATM 4 O2 UNL 1 -0.898 -0.839 0.288 1.00 0.00 O \\nHETATM 5 C3 UNL 1 -0.095 0.284 0.025 1.00 0.00 C \\nHETATM 6 C4 UNL 1 -0.671 1.501 -0.391 1.00 0.00 C \\nHETATM 7 C5 UNL 1 0.129 2.617 -0.638 1.00 0.00 C \\nHETATM 8 C6 UNL 1 1.509 2.540 -0.463 1.00 0.00 C \\nHETATM 9 C7 UNL 1 2.094 1.348 -0.032 1.00 0.00 C \\nHETATM 10 C8 UNL 1 1.307 0.209 0.223 1.00 0.00 C \\nHETATM 11 C9 UNL 1 1.972 -1.038 0.682 1.00 0.00 C \\nHETATM 12 O3 UNL 1 1.307 -2.087 0.898 1.00 0.00 O \\nHETATM 13 O4 UNL 1 3.351 -1.062 0.878 1.00 0.00 O \\nHETATM 14 H1 UNL 1 -4.038 -1.995 -0.040 1.00 0.00 H \\nHETATM 15 H2 UNL 1 -2.850 -2.271 1.291 1.00 0.00 H \\nHETATM 16 H3 UNL 1 -2.699 -3.195 -0.257 1.00 0.00 H \\nHETATM 17 H4 UNL 1 -1.742 1.599 -0.498 1.00 0.00 H \\nHETATM 18 H5 UNL 1 -0.324 3.547 -0.957 1.00 0.00 H \\nHETATM 19 H6 UNL 1 2.126 3.408 -0.653 1.00 0.00 H \\nHETATM 20 H7 UNL 1 3.168 1.322 0.102 1.00 0.00 H \\nHETATM 21 H8 UNL 1 3.818 -1.905 1.190 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5\\nCONECT 5 6 6 10\\nCONECT 6 7 17\\nCONECT 7 8 8 18\\nCONECT 8 9 19\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nCONECT 13 21\\nEND\\n\",\"pdb\");\n",
101
+ "\tviewer_1753263966683991.setStyle({\"stick\": {}});\n",
102
+ "\tviewer_1753263966683991.zoomTo();\n",
103
+ "viewer_1753263966683991.render();\n",
104
+ "});\n",
105
+ "</script>"
106
+ ]
107
+ },
108
+ "metadata": {},
109
+ "output_type": "display_data"
110
+ },
111
+ {
112
+ "data": {
113
+ "text/plain": [
114
+ "<py3Dmol.view at 0x2afd09c40>"
115
+ ]
116
+ },
117
+ "execution_count": 9,
118
+ "metadata": {},
119
+ "output_type": "execute_result"
120
+ }
121
+ ],
122
+ "source": [
123
+ "# Given a SMILES, get its fingerprint\n",
124
+ "smiles = \"CC(=O)OC1=CC=CC=C1C(=O)O\"\n",
125
+ "smiles_to_3d([smiles])"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 12,
131
+ "id": "05f9bf21-c998-4d63-870e-c1033ff91b31",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stderr",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
139
+ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
140
+ ]
141
+ },
142
+ {
143
+ "name": "stdout",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "[C][C][=Branch1][C][=O][O][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=Branch1][C][=S][O][SEP]\n",
147
+ "[C][C][=Branch1][C][=O][O][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=Branch1][C][=NH2+1][O]\n",
148
+ "CC(=O)OC1=CC=CC=C1C(=[NH2+1])O\n"
149
+ ]
150
+ },
151
+ {
152
+ "data": {
153
+ "application/3dmoljs_load.v0": "<div id=\"3dmolviewer_1753264053659028\" style=\"position: relative; width: 400px; height: 300px;\">\n <p id=\"3dmolwarning_1753264053659028\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n </div>\n<script>\n\nvar loadScriptAsync = function(uri){\n return new Promise((resolve, reject) => {\n //this is to ignore the existence of requirejs amd\n var savedexports, savedmodule;\n if (typeof exports !== 'undefined') savedexports = exports;\n else exports = {}\n if (typeof module !== 'undefined') savedmodule = module;\n else module = {}\n\n var tag = document.createElement('script');\n tag.src = uri;\n tag.async = true;\n tag.onload = () => {\n exports = savedexports;\n module = savedmodule;\n resolve();\n };\n var firstScriptTag = document.getElementsByTagName('script')[0];\n firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n});\n};\n\nif(typeof $3Dmolpromise === 'undefined') {\n$3Dmolpromise = null;\n $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n}\n\nvar viewer_1753264053659028 = null;\nvar warn = document.getElementById(\"3dmolwarning_1753264053659028\");\nif(warn) {\n warn.parentNode.removeChild(warn);\n}\n$3Dmolpromise.then(function() {\nviewer_1753264053659028 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_1753264053659028\"),{backgroundColor:\"white\"});\nviewer_1753264053659028.zoomTo();\n\tviewer_1753264053659028.addModel(\"HETATM 1 C1 UNL 1 2.411 -2.582 -1.327 1.00 0.00 C \\nHETATM 2 C2 UNL 1 1.777 -1.233 -1.311 1.00 0.00 C \\nHETATM 3 O1 UNL 1 1.971 -0.444 -2.275 1.00 0.00 O \\nHETATM 4 O2 UNL 1 0.887 -0.908 -0.278 1.00 0.00 O \\nHETATM 5 C3 UNL 1 0.406 0.380 0.005 1.00 0.00 C \\nHETATM 6 C4 UNL 1 1.187 1.525 -0.250 1.00 0.00 C \\nHETATM 7 C5 UNL 1 0.691 2.798 0.033 1.00 0.00 C \\nHETATM 8 C6 UNL 1 -0.580 2.947 0.586 1.00 0.00 C \\nHETATM 9 C7 
UNL 1 -1.357 1.822 0.866 1.00 0.00 C \\nHETATM 10 C8 UNL 1 -0.876 0.529 0.582 1.00 0.00 C \\nHETATM 11 C9 UNL 1 -1.702 -0.662 0.905 1.00 0.00 C \\nHETATM 12 N1 UNL 1 -2.996 -0.640 0.776 1.00 0.00 N1+\\nHETATM 13 O3 UNL 1 -1.073 -1.826 1.349 1.00 0.00 O \\nHETATM 14 H1 UNL 1 1.780 -3.281 -1.914 1.00 0.00 H \\nHETATM 15 H2 UNL 1 2.512 -2.964 -0.290 1.00 0.00 H \\nHETATM 16 H3 UNL 1 3.421 -2.524 -1.787 1.00 0.00 H \\nHETATM 17 H4 UNL 1 2.190 1.435 -0.643 1.00 0.00 H \\nHETATM 18 H5 UNL 1 1.299 3.672 -0.166 1.00 0.00 H \\nHETATM 19 H6 UNL 1 -0.957 3.936 0.813 1.00 0.00 H \\nHETATM 20 H7 UNL 1 -2.325 1.965 1.329 1.00 0.00 H \\nHETATM 21 H8 UNL 1 -3.567 -1.482 1.010 1.00 0.00 H \\nHETATM 22 H9 UNL 1 -3.499 0.198 0.408 1.00 0.00 H \\nHETATM 23 H10 UNL 1 -1.599 -2.662 1.576 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5\\nCONECT 5 6 6 10\\nCONECT 6 7 17\\nCONECT 7 8 8 18\\nCONECT 8 9 19\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nCONECT 12 21 22\\nCONECT 13 23\\nEND\\n\",\"pdb\");\n\tviewer_1753264053659028.setStyle({\"stick\": {}});\n\tviewer_1753264053659028.zoomTo();\nviewer_1753264053659028.render();\n});\n</script>",
154
+ "text/html": [
155
+ "<div id=\"3dmolviewer_1753264053659028\" style=\"position: relative; width: 400px; height: 300px;\">\n",
156
+ " <p id=\"3dmolwarning_1753264053659028\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n",
157
+ " </div>\n",
158
+ "<script>\n",
159
+ "\n",
160
+ "var loadScriptAsync = function(uri){\n",
161
+ " return new Promise((resolve, reject) => {\n",
162
+ " //this is to ignore the existence of requirejs amd\n",
163
+ " var savedexports, savedmodule;\n",
164
+ " if (typeof exports !== 'undefined') savedexports = exports;\n",
165
+ " else exports = {}\n",
166
+ " if (typeof module !== 'undefined') savedmodule = module;\n",
167
+ " else module = {}\n",
168
+ "\n",
169
+ " var tag = document.createElement('script');\n",
170
+ " tag.src = uri;\n",
171
+ " tag.async = true;\n",
172
+ " tag.onload = () => {\n",
173
+ " exports = savedexports;\n",
174
+ " module = savedmodule;\n",
175
+ " resolve();\n",
176
+ " };\n",
177
+ " var firstScriptTag = document.getElementsByTagName('script')[0];\n",
178
+ " firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n",
179
+ "});\n",
180
+ "};\n",
181
+ "\n",
182
+ "if(typeof $3Dmolpromise === 'undefined') {\n",
183
+ "$3Dmolpromise = null;\n",
184
+ " $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n",
185
+ "}\n",
186
+ "\n",
187
+ "var viewer_1753264053659028 = null;\n",
188
+ "var warn = document.getElementById(\"3dmolwarning_1753264053659028\");\n",
189
+ "if(warn) {\n",
190
+ " warn.parentNode.removeChild(warn);\n",
191
+ "}\n",
192
+ "$3Dmolpromise.then(function() {\n",
193
+ "viewer_1753264053659028 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_1753264053659028\"),{backgroundColor:\"white\"});\n",
194
+ "viewer_1753264053659028.zoomTo();\n",
195
+ "\tviewer_1753264053659028.addModel(\"HETATM 1 C1 UNL 1 2.411 -2.582 -1.327 1.00 0.00 C \\nHETATM 2 C2 UNL 1 1.777 -1.233 -1.311 1.00 0.00 C \\nHETATM 3 O1 UNL 1 1.971 -0.444 -2.275 1.00 0.00 O \\nHETATM 4 O2 UNL 1 0.887 -0.908 -0.278 1.00 0.00 O \\nHETATM 5 C3 UNL 1 0.406 0.380 0.005 1.00 0.00 C \\nHETATM 6 C4 UNL 1 1.187 1.525 -0.250 1.00 0.00 C \\nHETATM 7 C5 UNL 1 0.691 2.798 0.033 1.00 0.00 C \\nHETATM 8 C6 UNL 1 -0.580 2.947 0.586 1.00 0.00 C \\nHETATM 9 C7 UNL 1 -1.357 1.822 0.866 1.00 0.00 C \\nHETATM 10 C8 UNL 1 -0.876 0.529 0.582 1.00 0.00 C \\nHETATM 11 C9 UNL 1 -1.702 -0.662 0.905 1.00 0.00 C \\nHETATM 12 N1 UNL 1 -2.996 -0.640 0.776 1.00 0.00 N1+\\nHETATM 13 O3 UNL 1 -1.073 -1.826 1.349 1.00 0.00 O \\nHETATM 14 H1 UNL 1 1.780 -3.281 -1.914 1.00 0.00 H \\nHETATM 15 H2 UNL 1 2.512 -2.964 -0.290 1.00 0.00 H \\nHETATM 16 H3 UNL 1 3.421 -2.524 -1.787 1.00 0.00 H \\nHETATM 17 H4 UNL 1 2.190 1.435 -0.643 1.00 0.00 H \\nHETATM 18 H5 UNL 1 1.299 3.672 -0.166 1.00 0.00 H \\nHETATM 19 H6 UNL 1 -0.957 3.936 0.813 1.00 0.00 H \\nHETATM 20 H7 UNL 1 -2.325 1.965 1.329 1.00 0.00 H \\nHETATM 21 H8 UNL 1 -3.567 -1.482 1.010 1.00 0.00 H \\nHETATM 22 H9 UNL 1 -3.499 0.198 0.408 1.00 0.00 H \\nHETATM 23 H10 UNL 1 -1.599 -2.662 1.576 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5\\nCONECT 5 6 6 10\\nCONECT 6 7 17\\nCONECT 7 8 8 18\\nCONECT 8 9 19\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nCONECT 12 21 22\\nCONECT 13 23\\nEND\\n\",\"pdb\");\n",
196
+ "\tviewer_1753264053659028.setStyle({\"stick\": {}});\n",
197
+ "\tviewer_1753264053659028.zoomTo();\n",
198
+ "viewer_1753264053659028.render();\n",
199
+ "});\n",
200
+ "</script>"
201
+ ]
202
+ },
203
+ "metadata": {},
204
+ "output_type": "display_data"
205
+ },
206
+ {
207
+ "data": {
208
+ "text/plain": [
209
+ "<py3Dmol.view at 0x2afbaf020>"
210
+ ]
211
+ },
212
+ "execution_count": 12,
213
+ "metadata": {},
214
+ "output_type": "execute_result"
215
+ }
216
+ ],
217
+ "source": [
218
+ "s = sf.encoder(smiles)\n",
219
+ "s = s + \"[SEP]\"\n",
220
+ "print(s)\n",
221
+ "input_ids = tokenizer.encode(s, return_tensors=\"pt\")\n",
222
+ "n = input_ids.size(1)\n",
223
+ "# Generate output sequence\n",
224
+ "output_ids = model.generate(input_ids, max_length=128, num_beams=5, num_return_sequences=5,\n",
225
+ " early_stopping=True)\n",
226
+ "output = tokenizer.decode(output_ids[1][n:], skip_special_tokens=True)\n",
227
+ "print(output)\n",
228
+ "smiles = sf.decoder(output)\n",
229
+ "print(smiles)\n",
230
+ "smiles_to_3d([smiles])"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 13,
236
+ "id": "dbe6cebd-c7d5-4da9-aac0-114f232cf147",
237
+ "metadata": {},
238
+ "outputs": [
239
+ {
240
+ "name": "stderr",
241
+ "output_type": "stream",
242
+ "text": [
243
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
244
+ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
245
+ ]
246
+ },
247
+ {
248
+ "name": "stdout",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "[C][C][=Branch1][C][=O][N][C][=C][C][=N][C][=C][Ring1][=Branch1][C][=Branch1][C][=S][O-1]\n",
252
+ "CC(=O)NC1=CC=NC=C1C(=S)[O-1]\n"
253
+ ]
254
+ },
255
+ {
256
+ "data": {
257
+ "application/3dmoljs_load.v0": "<div id=\"3dmolviewer_17532640697518232\" style=\"position: relative; width: 400px; height: 300px;\">\n <p id=\"3dmolwarning_17532640697518232\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n </div>\n<script>\n\nvar loadScriptAsync = function(uri){\n return new Promise((resolve, reject) => {\n //this is to ignore the existence of requirejs amd\n var savedexports, savedmodule;\n if (typeof exports !== 'undefined') savedexports = exports;\n else exports = {}\n if (typeof module !== 'undefined') savedmodule = module;\n else module = {}\n\n var tag = document.createElement('script');\n tag.src = uri;\n tag.async = true;\n tag.onload = () => {\n exports = savedexports;\n module = savedmodule;\n resolve();\n };\n var firstScriptTag = document.getElementsByTagName('script')[0];\n firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n});\n};\n\nif(typeof $3Dmolpromise === 'undefined') {\n$3Dmolpromise = null;\n $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n}\n\nvar viewer_17532640697518232 = null;\nvar warn = document.getElementById(\"3dmolwarning_17532640697518232\");\nif(warn) {\n warn.parentNode.removeChild(warn);\n}\n$3Dmolpromise.then(function() {\nviewer_17532640697518232 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17532640697518232\"),{backgroundColor:\"white\"});\nviewer_17532640697518232.zoomTo();\n\tviewer_17532640697518232.addModel(\"HETATM 1 C1 UNL 1 3.341 0.506 -0.349 1.00 0.00 C \\nHETATM 2 C2 UNL 1 1.994 -0.109 -0.554 1.00 0.00 C \\nHETATM 3 O1 UNL 1 1.839 -0.974 -1.458 1.00 0.00 O \\nHETATM 4 N1 UNL 1 0.876 0.346 0.221 1.00 0.00 N \\nHETATM 5 C3 UNL 1 -0.424 -0.266 0.167 1.00 0.00 C \\nHETATM 6 C4 UNL 1 -0.537 -1.665 0.219 1.00 0.00 C \\nHETATM 7 C5 UNL 1 -1.798 -2.258 0.278 1.00 0.00 C \\nHETATM 8 N2 UNL 1 -2.913 -1.485 0.302 1.00 0.00 N 
\\nHETATM 9 C6 UNL 1 -2.844 -0.130 0.260 1.00 0.00 C \\nHETATM 10 C7 UNL 1 -1.602 0.521 0.177 1.00 0.00 C \\nHETATM 11 C8 UNL 1 -1.585 1.998 0.085 1.00 0.00 C \\nHETATM 12 S1 UNL 1 -0.600 2.783 -0.963 1.00 0.00 S \\nHETATM 13 O2 UNL 1 -2.485 2.742 0.838 1.00 0.00 O1-\\nHETATM 14 H1 UNL 1 3.531 0.642 0.736 1.00 0.00 H \\nHETATM 15 H2 UNL 1 4.133 -0.148 -0.770 1.00 0.00 H \\nHETATM 16 H3 UNL 1 3.376 1.493 -0.854 1.00 0.00 H \\nHETATM 17 H4 UNL 1 1.005 1.188 0.826 1.00 0.00 H \\nHETATM 18 H5 UNL 1 0.344 -2.292 0.249 1.00 0.00 H \\nHETATM 19 H6 UNL 1 -1.889 -3.335 0.323 1.00 0.00 H \\nHETATM 20 H7 UNL 1 -3.764 0.442 0.270 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5 17\\nCONECT 5 6 6 10\\nCONECT 6 7 18\\nCONECT 7 8 8 19\\nCONECT 8 9\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nEND\\n\",\"pdb\");\n\tviewer_17532640697518232.setStyle({\"stick\": {}});\n\tviewer_17532640697518232.zoomTo();\nviewer_17532640697518232.render();\n});\n</script>",
258
+ "text/html": [
259
+ "<div id=\"3dmolviewer_17532640697518232\" style=\"position: relative; width: 400px; height: 300px;\">\n",
260
+ " <p id=\"3dmolwarning_17532640697518232\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n",
261
+ " </div>\n",
262
+ "<script>\n",
263
+ "\n",
264
+ "var loadScriptAsync = function(uri){\n",
265
+ " return new Promise((resolve, reject) => {\n",
266
+ " //this is to ignore the existence of requirejs amd\n",
267
+ " var savedexports, savedmodule;\n",
268
+ " if (typeof exports !== 'undefined') savedexports = exports;\n",
269
+ " else exports = {}\n",
270
+ " if (typeof module !== 'undefined') savedmodule = module;\n",
271
+ " else module = {}\n",
272
+ "\n",
273
+ " var tag = document.createElement('script');\n",
274
+ " tag.src = uri;\n",
275
+ " tag.async = true;\n",
276
+ " tag.onload = () => {\n",
277
+ " exports = savedexports;\n",
278
+ " module = savedmodule;\n",
279
+ " resolve();\n",
280
+ " };\n",
281
+ " var firstScriptTag = document.getElementsByTagName('script')[0];\n",
282
+ " firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n",
283
+ "});\n",
284
+ "};\n",
285
+ "\n",
286
+ "if(typeof $3Dmolpromise === 'undefined') {\n",
287
+ "$3Dmolpromise = null;\n",
288
+ " $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js');\n",
289
+ "}\n",
290
+ "\n",
291
+ "var viewer_17532640697518232 = null;\n",
292
+ "var warn = document.getElementById(\"3dmolwarning_17532640697518232\");\n",
293
+ "if(warn) {\n",
294
+ " warn.parentNode.removeChild(warn);\n",
295
+ "}\n",
296
+ "$3Dmolpromise.then(function() {\n",
297
+ "viewer_17532640697518232 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17532640697518232\"),{backgroundColor:\"white\"});\n",
298
+ "viewer_17532640697518232.zoomTo();\n",
299
+ "\tviewer_17532640697518232.addModel(\"HETATM 1 C1 UNL 1 3.341 0.506 -0.349 1.00 0.00 C \\nHETATM 2 C2 UNL 1 1.994 -0.109 -0.554 1.00 0.00 C \\nHETATM 3 O1 UNL 1 1.839 -0.974 -1.458 1.00 0.00 O \\nHETATM 4 N1 UNL 1 0.876 0.346 0.221 1.00 0.00 N \\nHETATM 5 C3 UNL 1 -0.424 -0.266 0.167 1.00 0.00 C \\nHETATM 6 C4 UNL 1 -0.537 -1.665 0.219 1.00 0.00 C \\nHETATM 7 C5 UNL 1 -1.798 -2.258 0.278 1.00 0.00 C \\nHETATM 8 N2 UNL 1 -2.913 -1.485 0.302 1.00 0.00 N \\nHETATM 9 C6 UNL 1 -2.844 -0.130 0.260 1.00 0.00 C \\nHETATM 10 C7 UNL 1 -1.602 0.521 0.177 1.00 0.00 C \\nHETATM 11 C8 UNL 1 -1.585 1.998 0.085 1.00 0.00 C \\nHETATM 12 S1 UNL 1 -0.600 2.783 -0.963 1.00 0.00 S \\nHETATM 13 O2 UNL 1 -2.485 2.742 0.838 1.00 0.00 O1-\\nHETATM 14 H1 UNL 1 3.531 0.642 0.736 1.00 0.00 H \\nHETATM 15 H2 UNL 1 4.133 -0.148 -0.770 1.00 0.00 H \\nHETATM 16 H3 UNL 1 3.376 1.493 -0.854 1.00 0.00 H \\nHETATM 17 H4 UNL 1 1.005 1.188 0.826 1.00 0.00 H \\nHETATM 18 H5 UNL 1 0.344 -2.292 0.249 1.00 0.00 H \\nHETATM 19 H6 UNL 1 -1.889 -3.335 0.323 1.00 0.00 H \\nHETATM 20 H7 UNL 1 -3.764 0.442 0.270 1.00 0.00 H \\nCONECT 1 2 14 15 16\\nCONECT 2 3 3 4\\nCONECT 4 5 17\\nCONECT 5 6 6 10\\nCONECT 6 7 18\\nCONECT 7 8 8 19\\nCONECT 8 9\\nCONECT 9 10 10 20\\nCONECT 10 11\\nCONECT 11 12 12 13\\nEND\\n\",\"pdb\");\n",
300
+ "\tviewer_17532640697518232.setStyle({\"stick\": {}});\n",
301
+ "\tviewer_17532640697518232.zoomTo();\n",
302
+ "viewer_17532640697518232.render();\n",
303
+ "});\n",
304
+ "</script>"
305
+ ]
306
+ },
307
+ "metadata": {},
308
+ "output_type": "display_data"
309
+ },
310
+ {
311
+ "data": {
312
+ "text/plain": [
313
+ "<py3Dmol.view at 0x11fd5a780>"
314
+ ]
315
+ },
316
+ "execution_count": 13,
317
+ "metadata": {},
318
+ "output_type": "execute_result"
319
+ }
320
+ ],
321
+ "source": [
322
+ "input_ids[0][5] = tokenizer.mask_token_id\n",
323
+ "input_ids[0][9] = tokenizer.mask_token_id\n",
324
+ "input_ids[0][18] = tokenizer.mask_token_id\n",
325
+ "input_ids[0][11] = tokenizer.mask_token_id\n",
326
+ "# Generate output sequence\n",
327
+ "output_ids = model.generate(input_ids, max_length=128, num_beams=5, num_return_sequences=5,\n",
328
+ " early_stopping=True)\n",
329
+ "output = tokenizer.decode(output_ids[1][n:], skip_special_tokens=True)\n",
330
+ "print(output)\n",
331
+ "smiles = sf.decoder(output)\n",
332
+ "print(smiles)\n",
333
+ "smiles_to_3d([smiles])"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "f696bb9c-2870-4b0b-9b62-1623411e5df6",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": []
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "kernelspec": {
347
+ "display_name": "Python 3 (ipykernel)",
348
+ "language": "python",
349
+ "name": "python3"
350
+ },
351
+ "language_info": {
352
+ "codemirror_mode": {
353
+ "name": "ipython",
354
+ "version": 3
355
+ },
356
+ "file_extension": ".py",
357
+ "mimetype": "text/x-python",
358
+ "name": "python",
359
+ "nbconvert_exporter": "python",
360
+ "pygments_lexer": "ipython3",
361
+ "version": "3.12.10"
362
+ }
363
+ },
364
+ "nbformat": 4,
365
+ "nbformat_minor": 5
366
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3bac4c2f49b99725ee39be55476579193e474801c27e01ecf760df051daaa3b1
3
  size 327533344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d48f249b28c222dc74e38e28f43333a0ef3b5af65de24a70fd020c66fb7cf9b
3
  size 327533344
trainer_state.json CHANGED
The diff for this file is too large to render. See raw diff