| import logging |
| import pathlib |
|
|
| import gradio as gr |
| import pandas as pd |
| from gt4sd.algorithms.conditional_generation.guacamol import ( |
| AaeGenerator, |
| GraphGAGenerator, |
| GraphMCTSGenerator, |
| GuacaMolGenerator, |
| MosesGenerator, |
| OrganGenerator, |
| VaeGenerator, |
| SMILESGAGenerator, |
| SMILESLSTMHCGenerator, |
| SMILESLSTMPPOGenerator, |
| ) |
| from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
| from utils import draw_grid_generate |
|
|
# Library-style logging: attach a NullHandler so importing this module never
# emits "no handler" warnings; the application decides where logs go.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


TITLE = "GuacaMol & MOSES"


# Maps the algorithm name selected in the UI to its GT4SD configuration class.
# Only the Moses generators are exposed; a previous, immediately overwritten
# mapping that also listed the GuacaMol generators under "<family> - <name>"
# keys was dead code and has been removed.
CONFIG_FACTORY = {
    "AaeGenerator": AaeGenerator,
    "VaeGenerator": VaeGenerator,
    "OrganGenerator": OrganGenerator,
}

# Maps a model family name to the GT4SD generator class that wraps a
# configuration from CONFIG_FACTORY.
MODEL_FACTORY = {"Moses": MosesGenerator, "GuacaMol": GuacaMolGenerator}
|
|
|
|
def run_inference(
    algorithm_version: str,
    length: int,
    number_of_samples: int,
):
    """Sample molecules with a Moses generator and render them as a grid.

    Args:
        algorithm_version: Key into ``CONFIG_FACTORY`` selecting the
            configuration class to instantiate.
        length: Passed to the configuration as ``max_len``.
        number_of_samples: Passed to the configuration as ``n_samples`` and
            used as the number of items requested from ``model.sample``.

    Returns:
        The result of ``draw_grid_generate`` for the generated samples
        (no seeds, five columns).

    Raises:
        KeyError: If ``algorithm_version`` is not a key of ``CONFIG_FACTORY``.
    """
    config_class = CONFIG_FACTORY[algorithm_version]

    # Only the Moses family is served by this app. The former GuacaMol branch
    # was unreachable and referenced parameters this function never received,
    # so it has been removed rather than left as a latent NameError.
    model_class = MODEL_FACTORY["Moses"]

    config = config_class(n_samples=number_of_samples, max_len=length)
    model = model_class(configuration=config, target={})

    # Materialize the samples before handing them to the renderer.
    samples = list(model.sample(number_of_samples))

    return draw_grid_generate(seeds=[], samples=samples, n_cols=5)
|
|
|
|
if __name__ == "__main__":
    # Offer only the applications whose algorithm name mentions "Moses";
    # these are the names that CONFIG_FACTORY is keyed by.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_application"]
        for entry in all_algos
        if "Moses" in entry["algorithm_name"]
    ]

    # Static UI assets live in the "model_cards" directory next to this file.
    metadata_root = pathlib.Path(__file__).parent / "model_cards"

    # Example rows for the interface; empty cells become empty strings.
    examples = pd.read_csv(metadata_root / "examples.csv", header=None).fillna("")

    # Read the markdown with an explicit encoding so the result does not
    # depend on the platform's default locale.
    article = (metadata_root / "article.md").read_text(encoding="utf-8")
    description = (metadata_root / "description.md").read_text(encoding="utf-8")

    # The order of `inputs` must match the positional parameters of
    # run_inference: algorithm_version, length, number_of_samples.
    demo = gr.Interface(
        fn=run_inference,
        title="MOSES",
        inputs=[
            gr.Dropdown(algos, label="Algorithm version", value="AaeGenerator"),
            gr.Slider(
                minimum=5, maximum=500, value=100, label="Sequence length", step=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=5, label="Number of samples", step=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
|