| import json |
| import logging |
| import os |
| from collections import defaultdict |
| from typing import Dict, List, Tuple |
|
|
| import mols2grid |
| import pandas as pd |
| from rdkit import Chem |
| from terminator.selfies import decoder |
|
|
# Module-level logger. Attaching a NullHandler means this module emits nothing
# on its own; records are only output if the application configures logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
|
|
|
|
def draw_grid_generate(
    seeds: List[str],
    scaffolds: List[str],
    samples: List[str],
    n_cols: int = 5,
    size: Tuple[int, int] = (140, 200),
    height: int = 1100,
) -> str:
    """
    Uses mols2grid to draw an HTML grid of seed, scaffold and generated molecules.

    Molecules are laid out in the order seeds, scaffolds, samples. Each one is
    labelled ``Seed_<i>``, ``Scaffold_<i>`` or ``Generated_<i>``, where ``i``
    counts from 0 within its own group. Both the SMILES string and the label
    are shown in the tooltip.

    Args:
        seeds: SMILES strings of the seed molecules.
        scaffolds: SMILES strings of the scaffold molecules.
        samples: SMILES strings of the generated samples.
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).
        height: Height of the rendered grid, passed through to
            ``mols2grid.display``. Defaults to 1100.

    Returns:
        HTML to display
    """
    # Materialize each input once. Concatenation is then plain list
    # concatenation, even if a caller passes a tuple, generator or pandas
    # Series (for a Series, ``+`` would add element-wise). The labels are also
    # guaranteed to line up with the molecules that are drawn.
    seeds, scaffolds, samples = list(seeds), list(scaffolds), list(samples)

    names = (
        [f"Seed_{i}" for i in range(len(seeds))]
        + [f"Scaffold_{i}" for i in range(len(scaffolds))]
        + [f"Generated_{i}" for i in range(len(samples))]
    )

    # "SMILES" is the column mols2grid reads structures from by default.
    result_df = pd.DataFrame({"SMILES": seeds + scaffolds + samples, "Name": names})
    obj = mols2grid.display(
        result_df,
        tooltip=list(result_df.columns),
        height=height,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    # display() presumably returns an IPython HTML object; .data is its raw
    # HTML string (matches the ``-> str`` annotation).
    return obj.data
|
|