The model weights provided here are for solubility prediction with FragNet.

Install FragNet

FragNet is available on GitHub: https://github.com/pnnl/FragNet

To install FragNet, run the following commands:

# Clone the FragNet source repository and enter it
git clone https://github.com/pnnl/FragNet.git
cd FragNet
# Make sure a Python virtual environment is activated before installing
pip install --upgrade pip
pip install -r requirements.txt
# torch-scatter is installed from the PyG wheel index for torch 2.4.0 (CPU build).
# NOTE(review): the wheel URL presumably has to match the installed torch version
# and device (CPU/CUDA) — adjust it if requirements.txt pins a different torch.
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# Install the fragnet package itself from the repository root
pip install .

Load the model

import torch
from huggingface_hub import hf_hub_download
from fragnet.model.gat.gat2 import FragNetFineTune
from huggingface.fragnet_config import FragNetConfig

# Fetch the model configuration and weights from the Hugging Face Hub
# (hf_hub_download returns the local path of the cached file).
config_path = hf_hub_download(repo_id="gihan12/FragNet", filename="config.json")
model_path = hf_hub_download(repo_id="gihan12/FragNet", filename="pytorch_model.bin")

# Build the fine-tuning model from the hyper-parameters stored in config.json.
config = FragNetConfig.from_json_file(config_path)
model = FragNetFineTune(**config.get_model_kwargs())

# weights_only=True restricts unpickling to tensors and plain containers, so a
# downloaded checkpoint cannot execute arbitrary code while being loaded.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)

# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of dropping them silently: a mismatch can indicate that the config does
# not match the checkpoint, leaving some parameters at their random init values.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    print(f"Warning: parameters missing from checkpoint: {incompatible.missing_keys}")
if incompatible.unexpected_keys:
    print(f"Warning: unexpected checkpoint keys ignored: {incompatible.unexpected_keys}")

# Switch to evaluation mode (affects layers such as dropout) before inference.
model.eval()

Prepare molecule data in the FragNet input format

import pandas as pd
import pickle
from fragnet.dataset.data import CreateData
from fragnet.dataset.fragments import get_3Dcoords2

# A function to process SMILES
def smiles_to_fragnet_data(smiles, data_type="exp1s", frag_type="murcko"):
    """Convert a SMILES string into a FragNet input data object.

    Args:
        smiles: SMILES string of the molecule.
        data_type: Forwarded to ``CreateData(data_type=...)``.
        frag_type: Fragmentation scheme forwarded to ``create_data_point``
            (default ``"murcko"``).

    Returns:
        The object built by ``CreateData.create_data_point`` with its ``y``
        attribute flattened to 1-D. Returns ``None`` if 3D embedding fails,
        yields no conformers, or no data point is produced.
    """
    create_data = CreateData(
        data_type=data_type,
        create_bond_graph_data=True,
        add_dhangles=True,
    )

    # Generate 3D coordinates; get_3Dcoords2 signals failure by returning None.
    res = get_3Dcoords2(smiles, maxiters=500)
    if res is None:
        return None

    # get_3Dcoords2 returns (mol, list of (conf_id, energy)).
    mol, conf_res = res
    if not conf_res:
        return None

    # Select the lowest-energy conformer. min() is a single O(n) pass and, like
    # sorted(...)[0], returns the first entry among equal energies.
    best_conf_id = min(conf_res, key=lambda entry: entry[1])[0]
    best_conf = mol.GetConformer(best_conf_id)

    # create_data_point expects (smiles, y, mol, conf, frag_type). The target is
    # unknown at inference time, so 0.0 is passed as a placeholder y.
    data = create_data.create_data_point((smiles, 0.0, mol, best_conf, frag_type))
    if data is None:
        # Defensive: keep the None-on-failure contract rather than raising
        # AttributeError on ``.y`` below.
        return None

    # Flatten y to a 1-D tensor so that samples batch correctly.
    data.y = data.y.reshape(-1)

    return data
# Smoke-test the conversion on aspirin and report the atom and fragment
# feature-tensor shapes of the resulting data object.
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
data = smiles_to_fragnet_data(smiles)

if data is not None:
    print("✓ Data created successfully")
    print(f"  Atoms: {data.x_atoms.shape}")
    print(f"  Fragments: {data.x_frags.shape}")
else:
    print("✗ Failed to create data")

Make a prediction

from fragnet.dataset.data import collate_fn

# Run the model on the prepared molecule; nothing happens if preparation failed.
if data is not None:
    # collate_fn merges a list of data objects into a single model-ready batch.
    batch = collate_fn([data])

    # Only the forward pass needs gradient tracking disabled.
    with torch.no_grad():
        prediction = model(batch)

    print(f"\nPrediction for {smiles}")
    print(f"  Value: {prediction.item():.4f}")
Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support