The model weights provided here are for solubility prediction.
Install FragNet
FragNet is available on GitHub: https://github.com/pnnl/FragNet
To install FragNet, run the following commands:
# Fetch the FragNet sources and enter the repository directory
git clone https://github.com/pnnl/FragNet.git
cd FragNet
# make sure a python virtual environment is activated
# Upgrade pip itself before installing anything else
pip install --upgrade pip
# Install the dependencies listed in requirements.txt
pip install -r requirements.txt
# Install torch-scatter, adding the PyG wheel page for torch 2.4.0 (CPU build) as a package source
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# Install the package found in the current directory (the FragNet checkout)
pip install .
Load model
import torch
from huggingface_hub import hf_hub_download
from fragnet.model.gat.gat2 import FragNetFineTune
from huggingface.fragnet_config import FragNetConfig
# Download the configuration file and the weight file from the Hugging Face Hub.
config_path = hf_hub_download(repo_id="gihan12/FragNet", filename="config.json")
model_path = hf_hub_download(repo_id="gihan12/FragNet", filename="pytorch_model.bin")

# Build the network from the keyword arguments stored in the configuration.
config = FragNetConfig.from_json_file(config_path)
model = FragNetFineTune(**config.get_model_kwargs())

# weights_only=True restricts unpickling to tensors and plain containers, so a
# downloaded checkpoint cannot run arbitrary code while it is being loaded.
# NOTE(review): strict=False makes load_state_dict tolerate missing or
# unexpected keys instead of raising - confirm this is intended here.
model.load_state_dict(
    torch.load(model_path, map_location='cpu', weights_only=True),
    strict=False,
)
model.eval()  # put the module in evaluation mode before inference
Prepare molecule data the proper way
import pandas as pd
import pickle
from fragnet.dataset.data import CreateData
from fragnet.dataset.fragments import get_3Dcoords2
# A function to process SMILES
def smiles_to_fragnet_data(smiles: str, data_type: str = "exp1s", frag_type: str = "murcko"):
    """Convert a SMILES string into a FragNet data point for inference.

    Args:
        smiles: SMILES string of the molecule.
        data_type: Forwarded to ``CreateData`` as its ``data_type`` argument.
        frag_type: Passed to ``create_data_point`` as the last element of its
            argument tuple.

    Returns:
        The object produced by ``CreateData.create_data_point`` with its ``y``
        attribute reshaped to one dimension, or ``None`` when
        ``get_3Dcoords2`` returns ``None`` or reports no conformers.
    """
    create_data = CreateData(
        data_type=data_type,
        create_bond_graph_data=True,
        add_dhangles=True,
    )

    # Get 3D coordinates
    res = get_3Dcoords2(smiles, maxiters=500)
    if res is None:
        return None

    # get_3Dcoords2 returns (mol, list of (conf_id, energy))
    mol, conf_res = res
    if not conf_res:
        return None

    # Pick the lowest-energy conformer. min() scans the list once instead of
    # sorting every entry just to read the first one.
    best_conf_id = min(conf_res, key=lambda entry: entry[1])[0]
    best_conf = mol.GetConformer(best_conf_id)

    # create_data_point expects: (smiles, y, mol, conf, frag_type)
    # For inference, use a dummy y value (0.0) - it will be replaced by prediction
    args = (smiles, 0.0, mol, best_conf, frag_type)
    data = create_data.create_data_point(args)

    # Fix y to be 1D tensor for proper batching
    data.y = data.y.reshape(-1)
    return data
# Test with Aspirin
# Aspirin is used as the example molecule.
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
data = smiles_to_fragnet_data(smiles)

# smiles_to_fragnet_data returns None when no usable conformer was produced.
if data is not None:
    print("✓ Data created successfully")
    print(f" Atoms: {data.x_atoms.shape}")
    print(f" Fragments: {data.x_frags.shape}")
else:
    print("✗ Failed to create data")
Make prediction
from fragnet.dataset.data import collate_fn
if data is not None:
    # Wrap the single data point in a list so collate_fn builds a batch of one.
    batch = collate_fn([data])

    # Run the forward pass with gradient tracking disabled.
    with torch.no_grad():
        prediction = model(batch)

    # Report the scalar prediction for the input molecule.
    print(f"\nPrediction for {smiles}")
    print(f" Value: {prediction.item():.4f}")
- Downloads last month
- -
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support