| |
| """ |
| Self-Supervised Training for Molecular Representations (SMILES) |
| |
| Usage: |
| python trainbarlow.py --config config.yaml |
| """ |
| print("Initializing ...") |
| import os |
| import json |
| import argparse |
| import random |
| from pathlib import Path |
| from typing import Dict, Any, Tuple, List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from tqdm.auto import tqdm |
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.preprocessing import normalize |
|
|
| |
| from rdkit import RDLogger |
| RDLogger.DisableLog('rdApp.*') |
|
|
| try: |
| from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem |
| from rdkit import DataStructs |
| except ImportError: |
| raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit") |
|
|
| try: |
| from sentence_transformers import SentenceTransformer, InputExample |
| except ImportError: |
| raise ImportError("Install sentence-transformers: pip install sentence-transformers") |
|
|
|
|
| |
| |
| |
|
|
class BarlowTwinsProjector(nn.Module):
    """Three-layer MLP projection head with BatchNorm for Barlow Twins.

    Two Linear→BatchNorm→ReLU stages followed by a final Linear and a
    non-affine BatchNorm, so the projected features are standardized
    before the cross-correlation loss is applied.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()
        stages = []
        for fan_in, fan_out in ((in_dim, hidden_dim), (hidden_dim, hidden_dim)):
            stages += [
                nn.Linear(fan_in, fan_out, bias=False),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ]
        stages += [
            nn.Linear(hidden_dim, out_dim, bias=False),
            # affine=False: pure standardization, no learned scale/shift.
            nn.BatchNorm1d(out_dim, affine=False),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Project pooled embeddings into the Barlow Twins space."""
        return self.layers(x)
|
|
| |
| |
| |
|
|
class BarlowTwinsLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss.

    Both views are standardized jointly (mean/std computed over the
    concatenated 2B batch) and the off-diagonal penalty is scaled by the
    feature dimension d before the λ weighting.
    """

    def __init__(self, λ: float = 0.005):
        super().__init__()
        # Trade-off between invariance (diagonal) and redundancy (off-diagonal).
        self.λ = λ

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return (loss, diagnostics) for a pair of projected views of shape (B, d)."""
        batch, dim = z1.shape

        # Shared standardization: statistics pooled over both views.
        stacked = torch.cat((z1, z2), dim=0)
        stacked = (stacked - stacked.mean(dim=0)) / (stacked.std(dim=0) + 1e-8)
        view_a, view_b = stacked[:batch], stacked[batch:]

        # Cross-correlation matrix between the standardized views.
        corr = view_a.T @ view_b / batch
        diag = torch.diagonal(corr)

        invariance = ((1 - diag) ** 2).sum()
        redundancy = ((corr ** 2).sum() - (diag ** 2).sum()) / dim
        loss = invariance + self.λ * redundancy

        with torch.no_grad():
            mean_on = diag.mean().item()
            off_mask = ~torch.eye(dim, dtype=torch.bool, device=corr.device)
            mean_off = corr[off_mask].abs().mean().item()

        return loss, {
            'od': invariance.item(),
            'ofsc': (self.λ * redundancy).item(),
            'ofrw': redundancy.item(),
            'cr_onm': mean_on,
            'cr_offm': mean_off,
        }
|
|
|
|
| |
| |
| |
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML (.yaml/.yml) or JSON (.json) config file into a dict.

    Raises ValueError for any other file extension.
    """
    path = Path(config_path)
    suffix = path.suffix
    if suffix == '.json':
        with open(path) as fh:
            return json.load(fh)
    if suffix in ('.yaml', '.yml'):
        # Imported lazily so JSON-only users don't need PyYAML installed.
        import yaml
        with open(path) as fh:
            return yaml.safe_load(fh)
    raise ValueError(f"Unsupported config format: {suffix}")
|
|
def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce known config keys to float/int/bool in place and return the dict.

    YAML parsing and CLI overrides may leave numeric/boolean values as
    strings; only the listed keys are touched, anything else passes through.
    """
    coercions = (
        (float, ("LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
                 "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA")),
        (int, ("BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH",
               "SEED", "EVAL_EVERY_N_PERCENT")),
    )
    for caster, keys in coercions:
        for key in keys:
            if key in config:
                config[key] = caster(config[key])
    for key in ("BEST_BY_HEALTH",):
        if key in config:
            raw = config[key]
            if isinstance(raw, str):
                # Accept the usual truthy spellings, case-insensitively.
                config[key] = raw.lower() in {"true", "1", "yes", "on"}
            else:
                config[key] = bool(raw)
    return config
|
|
def set_seed(seed: int):
    """Seed torch, numpy, and the stdlib RNG (plus all CUDA devices if present)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
def enum_smiles(smi: str, k: int = 2) -> List[str]:
    """Return up to k distinct random SMILES enumerations of one molecule.

    Unparseable SMILES fall back to k copies of the input string. At most
    100 randomization attempts are made, so fewer than k variants may be
    returned for molecules with few distinct enumerations.
    """
    from rdkit.Chem import MolFromSmiles, MolToSmiles
    mol = MolFromSmiles(smi)
    if mol is None:
        return [smi] * k
    seen = set()
    for _ in range(100):
        if len(seen) >= k:
            break
        seen.add(MolToSmiles(mol, doRandom=True, canonical=False))
    return list(seen)[:k]
|
|
def tanimoto(s1: str, s2: str) -> float:
    """Tanimoto similarity of two SMILES using 2048-bit radius-2 Morgan fingerprints.

    Returns 0.0 when either SMILES fails to parse.
    """
    mols = [MolFromSmiles(s) for s in (s1, s2)]
    if not all(mols):
        return 0.0
    fps = [AllChem.GetMorganFingerprintAsBitVect(m, radius=2, nBits=2048) for m in mols]
    return DataStructs.TanimotoSimilarity(fps[0], fps[1])
|
|
def uniformity_metrics(emb: np.ndarray) -> Dict[str, float]:
    """Summarize how a set of embeddings spreads on the unit sphere.

    Returns mean/std of pairwise cosine similarity over distinct pairs,
    a uniformity score (log mean of exp(-2 * cosine distance)), a legacy
    health score (1 - mean similarity), and a collapse heuristic flag
    (very high mean similarity or very low spread).
    """
    unit = normalize(emb)
    sims = cosine_similarity(unit)
    # Exclude the trivial self-similarity diagonal.
    off = ~np.eye(len(sims), dtype=bool)
    pair = sims[off]
    mu = pair.mean()
    sigma = pair.std()
    cos_dist = 1 - sims
    uniformity = np.log(np.exp(-2 * cos_dist[off]).mean())
    return {
        'mean': float(mu),
        'std': float(sigma),
        'uniformity': float(uniformity),
        'health_old': float(1 - mu),
        'collapsed': mu > 0.7 or sigma < 0.05,
    }
|
|
def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device) -> torch.Tensor:
    """Tokenize texts, run the transformer, and return each sequence's first-token embedding.

    Pools by taking token position 0 (typically the [CLS] token) from the raw
    token embeddings, keeping the result differentiable for training.
    """
    features = model.tokenize(text_list)
    features = {name: tensor.to(device) for name, tensor in features.items()}
    out = model(features)
    return out['token_embeddings'][:, 0, :]
|
|
def evaluate(model, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
    """Measure alignment and uniformity of the model's embeddings on eval SMILES.

    Encodes the eval set twice — as-is and after one random SMILES
    re-enumeration per molecule — then reports pairwise-similarity metrics
    plus: alignment (1 - mean same-molecule cosine) and a Barlow health
    score (mean same-molecule cosine minus mean pairwise cosine; higher is
    better). Switches the model to eval mode and restores train mode before
    returning.
    """
    model.eval()
    with torch.no_grad():
        emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
    metrics = uniformity_metrics(emb)

    # A second, randomized SMILES form of each molecule probes invariance.
    alt_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
    with torch.no_grad():
        alt_emb = model.encode(alt_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)

    same_cos = np.diag(cosine_similarity(emb, alt_emb))
    alignment = 1 - same_cos.mean()
    barlow_health = same_cos.mean() - metrics['mean']

    print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={metrics['uniformity']:.3f}")
    print(f" Same-mol cos: {same_cos.mean():.3f}±{same_cos.std():.3f} | Pairwise: {metrics['mean']:.3f}±{metrics['std']:.3f}")
    print(f" Barlow Health: {barlow_health:.3f} (higher = better)")

    model.train()
    metrics['health'] = barlow_health
    metrics['alignment'] = alignment
    metrics['same_cos_mean'] = same_cos.mean()
    metrics['same_cos_std'] = same_cos.std()
    return metrics
|
|
|
|
| |
| |
| |
|
|
def main():
    """Train a SMILES sentence-transformer with the Barlow Twins objective.

    Loads a YAML/JSON config (CLI flags override config keys), builds two
    random SMILES enumerations per molecule as training views, and optimizes
    the encoder plus a projection head, periodically evaluating embedding
    health and checkpointing the best and final models under OUTPUT_DIR.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
    args = parser.parse_args()

    config = load_config(args.config)
    # Config keys are read in UPPERCASE ("EPOCHS", "LR", ...). Upper-case the
    # CLI names so overrides actually take effect — previously they were
    # stored under lowercase keys and silently ignored.
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key.upper()] = value
    config = sanitize_config(config)

    set_seed(config.get("SEED", 42))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config["OUTPUT_DIR"])
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(config["DATA_PATH"])
    smiles_list = df["SMILES"].dropna().tolist()
    print(f"📂 Loaded {len(smiles_list)} SMILES")

    # Two random enumerations per molecule form the positive pair.
    train_examples = []
    for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
        variants = enum_smiles(smi, 2)
        if len(variants) < 2:
            # Degenerate molecule with a single enumeration: duplicate it.
            variants = [smi, smi]
        train_examples.append(InputExample(texts=[variants[0], variants[1]]))
    print(f" Created {len(train_examples)} pairs")

    eval_size = min(200, len(smiles_list))
    eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()

    # Warm-started encoder; path is now configurable, defaulting to the
    # previously hard-coded checkpoint for backward compatibility.
    model = SentenceTransformer(config.get("MODEL_PATH", './chmbedv2-warmup-l5/final'))
    model.max_seq_length = config.get("MAX_LENGTH", 512)
    embed_dim = model.get_sentence_embedding_dimension()

    loss_type = config.get("LOSS_TYPE", "barlow")
    if loss_type == "barlow":
        projector = BarlowTwinsProjector(
            embed_dim,
            hidden_dim=2048,
            out_dim=2048
        ).to(device)
        train_loss = BarlowTwinsLoss(
            λ=config.get("BARLOW_LAMBDA", 0.005)
        ).to(device)
    else:
        # vicreg / corinfomax are accepted on the CLI but not implemented here.
        raise ValueError(f"Unknown loss_type: {loss_type}")

    model.to(device)

    from ranger21 import Ranger21

    batch_size = config.get("BATCH_SIZE", 8)
    effective_batch = config.get("EFFECTIVE_BATCH", 32)
    # Guard against batch_size > effective_batch, which would make grad_acc 0
    # and crash on the loss / grad_acc division below.
    grad_acc = max(1, effective_batch // batch_size)
    epochs = config.get("EPOCHS", 1)
    # Floor at 1 so tiny datasets don't produce zero-length schedules.
    num_batches_per_epoch = max(1, len(train_examples) // effective_batch)
    total_steps = num_batches_per_epoch * epochs
    train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

    # No weight decay on biases and LayerNorm parameters. (This param-group
    # construction was duplicated verbatim in the original; built once now.)
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    model_params = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]

    optimizer = Ranger21(
        model_params + [{"params": projector.parameters(), "weight_decay": config.get("WEIGHT_DECAY", 0.01)}],
        lr=config.get("LR", 1e-5),
        num_epochs=epochs,
        num_batches_per_epoch=num_batches_per_epoch,
        weight_decay=0.0,  # decay is handled per param group above
    )

    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
    )

    model.train()
    step = 0
    best_health = 0.0
    best_step = 0
    # Evaluate every EVAL_EVERY_N_PERCENT of total optimizer steps (min 1).
    log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))

    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            texts = [[ex.texts[i] for ex in batch] for i in range(2)]
            z1 = forward_pooled(model, texts[0], device)
            z2 = forward_pooled(model, texts[1], device)
            p1 = projector(z1)
            p2 = projector(z2)
            loss, extras = train_loss(p1, p2)

            # Scale so accumulated gradients match one large-batch step.
            loss = loss / grad_acc
            loss.backward()

            if (batch_idx + 1) % grad_acc == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                # Evaluate only when an optimizer step actually happened.
                # Previously this check ran on every micro-batch, so it fired
                # repeatedly at step 0 (before any update) and again on each
                # accumulation micro-batch at the same step.
                if step % log_interval == 0 or step == total_steps:
                    um = evaluate(model, eval_smiles, device, step)
                    if config.get("BEST_BY_HEALTH", True) and um["health"] > best_health:
                        best_health, best_step = um["health"], step
                        model.save(str(output_dir / "best"))

            postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
            for k, v in extras.items():
                postfix[k] = f"{v:.3f}"
            pbar.set_postfix(postfix)

    model.save(str(output_dir / "final"))
    print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")
|
|
|
|
| if __name__ == "__main__": |
| main() |