---
license: bsd-2-clause
---

The model weights provided here are for solubility prediction.

# Install FragNet

FragNet is available on GitHub: https://github.com/pnnl/FragNet

To install FragNet, run the following commands:

```bash
git clone https://github.com/pnnl/FragNet.git
cd FragNet

# make sure a python virtual environment is activated
pip install --upgrade pip
pip install -r requirements.txt
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
pip install .
```

# Load model

```python
import torch
from huggingface_hub import hf_hub_download
from fragnet.model.gat.gat2 import FragNetFineTune
from huggingface.fragnet_config import FragNetConfig

config_path = hf_hub_download(repo_id="gihan12/FragNet", filename="config.json")
model_path = hf_hub_download(repo_id="gihan12/FragNet", filename="pytorch_model.bin")

config = FragNetConfig.from_json_file(config_path)
model = FragNetFineTune(**config.get_model_kwargs())
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
model.eval()
```

# Prepare molecule data the proper way

```python
import pandas as pd
import pickle
from fragnet.dataset.data import CreateData
from fragnet.dataset.fragments import get_3Dcoords2


# A function to process SMILES
def smiles_to_fragnet_data(smiles, data_type="exp1s", frag_type="murcko"):
    """Convert SMILES to FragNet data format."""
    create_data = CreateData(
        data_type=data_type,
        create_bond_graph_data=True,
        add_dhangles=True,
    )

    # Get 3D coordinates
    res = get_3Dcoords2(smiles, maxiters=500)
    if res is None:
        return None

    mol, conf_res = res

    # get_3Dcoords2 returns (mol, list of (conf_id, energy))
    # We need to get the conformer with the lowest energy
    if not conf_res:
        return None

    # Sort by energy and get the best conformer
    conf_res_sorted = sorted(conf_res, key=lambda x: x[1])
    best_conf_id = conf_res_sorted[0][0]
    best_conf = mol.GetConformer(best_conf_id)

    # create_data_point expects: (smiles, y, mol, conf, frag_type)
    # For inference, use a dummy y value (0.0) - it will be replaced by prediction
    args = (smiles, 0.0, mol, best_conf, frag_type)
    data = create_data.create_data_point(args)

    # Fix y to be 1D tensor for proper batching
    data.y = data.y.reshape(-1)

    return data
```

```python
# Test with Aspirin
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
data = smiles_to_fragnet_data(smiles)

if data is not None:
    print("✓ Data created successfully")
    print(f" Atoms: {data.x_atoms.shape}")
    print(f" Fragments: {data.x_frags.shape}")
else:
    print("✗ Failed to create data")
```

# Make prediction

```python
from fragnet.dataset.data import collate_fn

if data is not None:
    # Create batch using the proper collate function
    batch = collate_fn([data])

    # Predict
    with torch.no_grad():
        prediction = model(batch)

    print(f"\nPrediction for {smiles}")
    print(f" Value: {prediction.item():.4f}")
```