import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.ML.Cluster import Butina
from lightning.pytorch import seed_everything
import torch
from tqdm import tqdm
from transformers import AutoModelForMaskedLM
from datasets import Dataset, DatasetDict
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

seed_everything(1986)

# Starting with a raw dataframe, using caco2 as an example.
df = pd.read_csv("caco2.csv")

# Parse every SMILES, drop rows RDKit cannot parse, and keep the canonical
# (isomeric) SMILES so exact duplicates can be detected later.
mols = []
canon = []
keep_rows = []
bad = 0
for i, smi in enumerate(df["SMILES"].astype(str)):
    m = Chem.MolFromSmiles(smi)
    if m is None:
        bad += 1
        continue
    smi_can = Chem.MolToSmiles(m, canonical=True, isomericSmiles=True)
    mols.append(m)
    canon.append(smi_can)
    keep_rows.append(i)

df = df.iloc[keep_rows].reset_index(drop=True)
df["SMILES_CANON"] = canon
print(f"Invalid SMILES dropped: {bad} / {len(df) + bad}")

# Drop exact duplicate molecules (same canonical SMILES), keeping the first occurrence.
dup_mask = df.duplicated(subset=["SMILES_CANON"], keep="first")
df = df.loc[~dup_mask].reset_index(drop=True)
mols = [m for m, isdup in zip(mols, dup_mask) if not isdup]

# Morgan fingerprints (radius 2, 2048 bits, chirality-aware).
morgan = AllChem.GetMorganGenerator(radius=2, fpSize=2048, includeChirality=True)
fps = [morgan.GetFingerprint(m) for m in mols]

# Cluster with the Butina algorithm: molecules at least `sim_thresh` similar
# (Tanimoto) to a cluster centroid join that cluster.
sim_thresh = 0.6
dist_thresh = 1.0 - sim_thresh

# Butina.ClusterData expects the flattened lower triangle of the distance
# matrix (d(1,0), d(2,0), d(2,1), ...), which is what this loop builds.
dists = []
n = len(fps)
for i in range(1, n):
    sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
    dists.extend([1.0 - x for x in sims])
clusters = Butina.ClusterData(dists, nPts=n, distThresh=dist_thresh, isDistData=True)

cluster_ids = np.empty(n, dtype=int)
for cid, idxs in enumerate(clusters):
    for idx in idxs:
        cluster_ids[idx] = cid
df["cluster_id"] = cluster_ids

# Split by clusters: whole clusters go to either train or val, so similar
# molecules never straddle the split. Clusters are shuffled, then added to the
# train set greedily until ~train_fraction of the molecules are covered.
train_fraction = 0.8
rng = np.random.default_rng(1986)  # seeded explicitly; seed_everything does not seed default_rng
unique_clusters = df["cluster_id"].unique()
rng.shuffle(unique_clusters)
train_target = int(train_fraction * len(df))
train_clusters = set()
count = 0
for cid in unique_clusters:
    csize = (df["cluster_id"] == cid).sum()
    if count + csize <= train_target:
        train_clusters.add(cid)
        count += csize

df["split"] = np.where(df["cluster_id"].isin(train_clusters), "train", "val")
df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
df.to_csv("caco2_meta_with_split.csv", index=False)
print(df["split"].value_counts())
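
# --- Optional follow-up (sketch) --------------------------------------------
# The imports of datasets / transformers / SMILES_SPE_Tokenizer above suggest
# the split feeds a masked-LM pipeline. As a minimal sketch, and assuming only
# the columns written by this script plus the standard `datasets` API, the two
# splits can be wrapped into a DatasetDict like this; any tokenization or
# model step is outside this script. The variable name `splits` is just for
# illustration.
splits = DatasetDict(
    {
        "train": Dataset.from_pandas(df[df["split"] == "train"].reset_index(drop=True)),
        "val": Dataset.from_pandas(df[df["split"] == "val"].reset_index(drop=True)),
    }
)
print(splits)  # row counts and column names per split
print(splits["train"][0]["SMILES_CANON"])  # one canonical SMILES example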