GLP-ESM2-3B-Layer18

A Generative Latent Prior (GLP) model trained on ESM2-3B Layer 18 activations from UniRef50 protein sequences. GLP learns the distribution of protein representations via flow matching, enabling on-manifold projection during protein sequence steering.

This is the 3B counterpart of GLP-ESM2-650M-Layer17, with twice the hidden dimension (2560 vs 1280) and about 4x the denoiser parameters (1.34B vs 335M).

Model Details

| Property | Value |
|---|---|
| Base PLM | ESM2-3B (esm2_t36_3B_UR50D) |
| Extraction Layer | 18 (of 36) |
| Input Dimension | 2560 |
| Denoiser Architecture | 6-layer SwiGLU MLP, d_model=5120, d_mlp=10240 |
| Denoiser Parameters | ~1.34B |
| Training Data | UniRef50 (4M sequences, online extraction) |
| Training | 1 epoch, batch_size=8192, grad_accum=2, lr=5e-5, cosine schedule |
| Training GPUs | 4x (multi-GPU producer-consumer pipeline) |
| Framework | Flow Matching (FlowMatchEulerDiscreteScheduler) |

Evaluation Metrics

| Metric | Value |
|---|---|
| Loss | 1.068 |
| FID (PCA-128d) | 682.5 |
| MMD (RBF) | 0.132 |
| NLL (Hutchinson) | 2292.1 |
| gen_mean_err | 0.250 |
| gen_std_err | 0.076 |

Note: Higher FID and NLL than the 650M model are expected because the activation space is twice as large (2560 vs 1280 dimensions), so these numbers are not directly comparable across the two models.
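To make the FID (PCA-128d) and MMD (RBF) rows concrete, here is a minimal NumPy sketch of one common way to compute them: fit PCA on the real activations, project both sets onto the top components, and take the Fréchet distance between Gaussian fits; MMD uses the biased estimator with a Gaussian (RBF) kernel. The PCA fitting choice, the kernel bandwidth `sigma`, and the function names are assumptions for illustration. `glp/script_eval.py` may compute these metrics differently, so this sketch is not guaranteed to reproduce the table.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def fid_pca(real: np.ndarray, gen: np.ndarray, n_components: int = 128) -> float:
    """Frechet distance between Gaussian fits after projecting onto real-data PCA."""
    mu = real.mean(axis=0)
    # PCA basis from the centered real activations (rows of vt are components)
    _, _, vt = np.linalg.svd(real - mu, full_matrices=False)
    basis = vt[:n_components].T
    r, g = (real - mu) @ basis, (gen - mu) @ basis
    mu_r, mu_g = r.mean(axis=0), g.mean(axis=0)
    cov_r = np.atleast_2d(np.cov(r, rowvar=False))
    cov_g = np.atleast_2d(np.cov(g, rowvar=False))
    # tr(sqrt(cov_r @ cov_g)) computed stably as tr(sqrt(S @ cov_g @ S)), S = sqrt(cov_r)
    s_r = _sqrtm_psd(cov_r)
    cross = np.sqrt(np.clip(np.linalg.eigvalsh(s_r @ cov_g @ s_r), 0.0, None)).sum()
    mean_term = ((mu_r - mu_g) ** 2).sum()
    return float(mean_term + np.trace(cov_r) + np.trace(cov_g) - 2.0 * cross)


def mmd_rbf(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased MMD^2 estimate with the RBF kernel exp(-||a - b||^2 / (2 sigma^2))."""

    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
        return np.exp(-np.clip(sq, 0.0, None) / (2.0 * sigma ** 2))

    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())
```

Working in a fixed 128-dimensional PCA space keeps the covariance estimates well conditioned even when the raw activations are 2560-dimensional.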

Files

| File | Description |
|---|---|
| final.safetensors | Model weights (denoiser), ~5.0 GB |
| config.yaml | Full model + training configuration |
| rep_statistics.pt | Per-dimension mean/var for z-score normalization |
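`rep_statistics.pt` holds the per-dimension statistics behind `model.normalizer`. Below is a minimal NumPy sketch of z-score normalization and its inverse. The class name and the `eps` term are assumptions, and the actual keys stored in `rep_statistics.pt` and the repo's `Normalizer` (in `glp/denoiser.py`) may differ.

```python
import numpy as np


class ZScoreNormalizer:
    """Hypothetical per-dimension z-score normalizer (not the repo's Normalizer class)."""

    def __init__(self, mean: np.ndarray, var: np.ndarray, eps: float = 1e-6):
        self.mean = mean
        self.std = np.sqrt(var + eps)  # eps guards against zero-variance dimensions

    def normalize(self, acts: np.ndarray) -> np.ndarray:
        # Broadcasts over leading dims, e.g. (B, 1, 2560) activations with (2560,) stats
        return (acts - self.mean) / self.std

    def denormalize(self, latents: np.ndarray) -> np.ndarray:
        # Exact inverse of normalize()
        return latents * self.std + self.mean
```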

Usage

Installation

```bash
git clone https://github.com/zhangshuibai/Steering-PLMs.git
cd Steering-PLMs
pip install torch esm safetensors omegaconf einops diffusers
```

1. Load the Model

```python
import sys

# Run from the repository root so the glp package can be imported
sys.path.insert(0, "generative_latent_prior")

from glp.denoiser import load_glp

# Load from a local path (after downloading)
model = load_glp("generative_latent_prior/runs/glp-esm2-3b-layer18-d6", device="cuda:0")

# Or load directly from the Hugging Face Hub
model = load_glp("Shuibai12138/glp-esm2-3b-layer18", device="cuda:0")
```

2. Generate Activations from Noise (Unconditional Sampling)

Sample synthetic ESM2-3B Layer 18 activations from the learned distribution:

```python
import torch
from glp import flow_matching

# Sample from Gaussian noise in the normalized (z-scored) space
noise = torch.randn(100, 1, 2560, device="cuda:0")  # 100 samples, dim=2560
with torch.no_grad():
    gen_acts = flow_matching.sample(model, noise, num_timesteps=100)

# Denormalize back to the raw ESM2-3B activation space
gen_acts = model.normalizer.denormalize(gen_acts)  # (100, 1, 2560)
```
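Conceptually, unconditional sampling integrates a learned velocity field from pure noise at t=1 down to data at t=0 with Euler steps. The NumPy sketch below assumes the linear path x_t = (1 - t) x_0 + t ε with velocity target ε - x_0 (to our understanding, the convention of diffusers' FlowMatchEulerDiscreteScheduler) and a uniform time grid. It replaces the GLP denoiser with a hypothetical closed-form velocity for Gaussian data N(m, s²I) so it runs without the model. The repo's `flow_matching.sample` may use a different timestep schedule.

```python
import numpy as np


def gaussian_velocity(x: np.ndarray, t: float, m: float, s: float) -> np.ndarray:
    """Exact flow-matching velocity E[eps - x0 | x_t] when x0 ~ N(m, s^2 I).

    Hypothetical toy stand-in for the trained GLP denoiser.
    """
    var_t = (1.0 - t) ** 2 * s ** 2 + t ** 2
    return (t - (1.0 - t) * s ** 2) / var_t * (x - (1.0 - t) * m) - m


def euler_sample(velocity, noise: np.ndarray, num_timesteps: int = 100) -> np.ndarray:
    """Integrate dx/dt = v(x, t) from t=1 (noise) to t=0 (data) with Euler steps."""
    ts = np.linspace(1.0, 0.0, num_timesteps + 1)
    x = noise.copy()
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t) * velocity(x, t)  # t_next < t: step backward in time
    return x


rng = np.random.default_rng(0)
noise = rng.standard_normal((10_000, 4))
samples = euler_sample(lambda x, t: gaussian_velocity(x, t, m=2.0, s=0.5), noise)
print(samples.mean(), samples.std())  # approximately 2.0 and 0.5
```

More timesteps reduce the Euler discretization error at the cost of more denoiser calls.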

3. On-Manifold Projection (SDEdit for Protein Steering)

The primary use case: after applying a steering vector to ESM2-3B activations, project the steered activations back onto the natural protein manifold to maintain sequence naturalness.

```python
import torch
from glp import flow_matching


@torch.no_grad()
def project_on_manifold(model, acts, u=0.5, num_timesteps=20):
    """
    SDEdit-style projection onto the learned activation manifold.

    Args:
        model: loaded GLP model
        acts: (B, 1, 2560) raw ESM2-3B activations (possibly steered)
        u: noise level in [0, 1] (0 = no change, 0.5 = moderate,
           1.0 = input fully replaced by noise, i.e. unconditional sampling)
        num_timesteps: number of denoising steps

    Returns:
        (B, 1, 2560) on-manifold activations
    """
    model.scheduler.set_timesteps(num_timesteps)

    # Normalize to the z-scored space the denoiser was trained on
    latents = model.normalizer.normalize(acts)

    # Add noise up to level u (same u for every sample in the batch)
    noise = torch.randn_like(latents)
    noisy_latents, _, timesteps, _ = flow_matching.fm_prepare(
        model.scheduler, latents, noise,
        u=torch.ones(latents.shape[0], device=latents.device) * u,
    )

    # Denoise starting from the noised timestep (SDEdit)
    latents = flow_matching.sample_on_manifold(
        model, noisy_latents,
        start_timestep=timesteps[0].item(),
        num_timesteps=num_timesteps,
    )

    # Denormalize back to the raw ESM2-3B activation space
    return model.normalizer.denormalize(latents)
```
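A toy setup shows what `project_on_manifold` does: noise the input to level u along the path x_t = (1 - u) x + u ε, then integrate the velocity field from t=u back to 0. This NumPy sketch assumes Euler steps on a uniform grid and a hypothetical closed-form velocity for Gaussian data N(m, s²I) standing in for the GLP denoiser. The names are illustrative, and the repo's `fm_prepare`/`sample_on_manifold` may differ in schedule details.

```python
import numpy as np


def gaussian_velocity(x, t, m=0.0, s=1.0):
    """Exact velocity E[eps - x0 | x_t] for x0 ~ N(m, s^2 I); toy stand-in for GLP."""
    var_t = (1.0 - t) ** 2 * s ** 2 + t ** 2
    return (t - (1.0 - t) * s ** 2) / var_t * (x - (1.0 - t) * m) - m


def sdedit(x, u, velocity, num_timesteps=20, rng=None):
    """Noise x to level u on x_t = (1 - t) x + t * eps, then denoise back to t = 0."""
    if u <= 0.0:
        return x.copy()  # u = 0: no noise is added, so the input is returned unchanged
    rng = np.random.default_rng() if rng is None else rng
    x_t = (1.0 - u) * x + u * rng.standard_normal(x.shape)
    ts = np.linspace(u, 0.0, num_timesteps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        x_t = x_t + (t_next - t) * velocity(x_t, t)  # Euler step toward t = 0
    return x_t


rng = np.random.default_rng(0)
off_manifold = np.full((2000, 1), 10.0)  # far from the toy data mean m = 0
for u in (0.0, 0.5, 1.0):
    out = sdedit(off_manifold, u, gaussian_velocity, rng=rng)
    print(u, round(float(out.mean()), 2))
```

With u = 0 the input is returned unchanged. With u = 0.5 the output mean lands between the input and the data mean. With u = 1 the input is discarded entirely and the result is an unconditional sample. This is the trade-off behind choosing a moderate u for steering.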

4. Full Steering + GLP Pipeline

For the 650M model, a ready-to-use steering script is available:

```bash
python steering_with_glp.py \
    --glp_path generative_latent_prior/runs/glp-esm2-650m-layer17-d6 \
    --gpu_gen cuda:0 --gpu_ppl 0 1 2 3 \
    --n_gen 100 --u 0.5
```

To use the 3B GLP for steering, adapt steering_with_glp.py with:

  • ESM2-3B as the base model (requires ~11.5 GB GPU memory)
  • Layer 18 as the steering + projection layer
  • Input dimension 2560
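For orientation, additive steering at the projection layer amounts to adding a scaled direction to the layer-18 activations before projecting them with `project_on_manifold`. The NumPy sketch below derives the direction as a difference of means between activations of proteins with and without a target property, which is one common recipe. This choice, the `alpha` scale, and the function names are assumptions. `addition_intervention()` in `script_steer.py` may compute or apply the vector differently.

```python
import numpy as np


def mean_difference_vector(pos_acts: np.ndarray, neg_acts: np.ndarray) -> np.ndarray:
    """Hypothetical steering direction: mean(positive) - mean(negative), shape (2560,)."""
    return pos_acts.mean(axis=0) - neg_acts.mean(axis=0)


def add_steering(acts: np.ndarray, direction: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    """Add alpha * direction to every position; acts has shape (..., 2560)."""
    return acts + alpha * direction


rng = np.random.default_rng(0)
pos = rng.normal(0.5, 1.0, size=(256, 2560))  # stand-in activations, not real ESM2 outputs
neg = rng.normal(0.0, 1.0, size=(256, 2560))
steer = mean_difference_vector(pos, neg)
acts = rng.normal(size=(8, 1, 2560))  # (B, 1, 2560), matching the GLP input shape
steered = add_steering(acts, steer, alpha=2.0)
print(steered.shape)  # (8, 1, 2560)
```

The steered activations would then be passed to the SDEdit projection before decoding sequences.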

Key Code Files

| File | Role |
|---|---|
| steering_with_glp.py | Main entry: steering + GLP projection + sequence generation + evaluation |
| generative_latent_prior/glp/denoiser.py | Model definition: Normalizer, Denoiser, GLP, load_glp() |
| generative_latent_prior/glp/flow_matching.py | fm_prepare(), sample(), sample_on_manifold() |
| generative_latent_prior/glp/script_steer.py | Generic steering utilities: addition_intervention(), postprocess_on_manifold_wrapper() |
| generative_latent_prior/glp/script_eval.py | FID evaluation and PCA visualization |

Training

This model was trained with the online pipeline (glp_train_online.py), which was designed for ESM2-3B because storing its activations offline would require ~4.4 TB of disk space. The online approach:

  1. Loads ESM2-3B (frozen) across multiple GPUs as producer workers
  2. Generates Layer 18 activations on-the-fly from UniRef50 sequences
  3. Computes per-dimension mean/var statistics online (Welford algorithm with float64 precision)
  4. Trains the GLP denoiser in a consumer process via flow matching
  5. Discards activations immediately (zero disk storage for activations)
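Step 3 can be sketched as a batched Welford update (Chan et al.'s parallel combination) that accumulates per-dimension mean and variance in float64, so no activations need to be kept. This NumPy sketch is illustrative. The class name and the use of population variance are assumptions, and `glp_train_online.py` may organize the update differently (for example, per worker and then merged).

```python
import numpy as np


class RunningStats:
    """Per-dimension running mean/variance via batched Welford (Chan) merges."""

    def __init__(self, dim: int):
        self.count = 0
        self.mean = np.zeros(dim, dtype=np.float64)
        self.m2 = np.zeros(dim, dtype=np.float64)  # running sum of squared deviations

    def update(self, batch: np.ndarray) -> None:
        # Flatten leading dims (e.g. batch x tokens) and upcast to float64
        batch = batch.reshape(-1, batch.shape[-1]).astype(np.float64)
        n_b = batch.shape[0]
        if n_b == 0:
            return
        mean_b = batch.mean(axis=0)
        m2_b = ((batch - mean_b) ** 2).sum(axis=0)
        total = self.count + n_b
        delta = mean_b - self.mean
        # Merge the batch statistics into the running ones
        self.mean = self.mean + delta * (n_b / total)
        self.m2 = self.m2 + m2_b + delta ** 2 * (self.count * n_b / total)
        self.count = total

    @property
    def var(self) -> np.ndarray:
        """Population variance (divide by N); all zeros before any data arrives."""
        return self.m2 / max(self.count, 1)
```

Merging batch summaries avoids the catastrophic cancellation of the naive sum-of-squares formula, which matters when accumulating over hundreds of millions of activation vectors.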

For training details, see:

Reproducing Training

```bash
# Multi-GPU online training (GPUs 4,5,6,7)
python generative_latent_prior/glp_train_online.py \
    --config generative_latent_prior/configs/glp-esm2-3b-layer18-d6-online.yaml \
    --gpu_ids 4 5 6 7

# Evaluate existing checkpoint
python generative_latent_prior/glp_train_online.py \
    --eval_existing generative_latent_prior/runs/glp-esm2-3b-layer18-d6 \
    --esm_model_size 3B --extract_layer 18 --gpu_ids 0
```

How GLP Works

GLP (Generative Latent Prior) learns the distribution of ESM2 internal activations using flow matching:

```text
Training:    Real ESM2 activations --> z-score normalize --> flow matching loss (predict velocity field)
Sampling:    Gaussian noise --> denoise with learned velocity --> denormalize --> synthetic activations
SDEdit:      Steered activation --> normalize --> add noise (level u) --> denoise --> denormalize --> on-manifold activation
```

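The key insight: steering vectors can push activations off the natural protein manifold, degrading sequence quality. GLP's SDEdit projection pulls them back toward the manifold while aiming to preserve the steering direction. A larger noise level u pulls harder toward the manifold but retains less of the steered input.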
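The training line of the diagram corresponds to the standard conditional flow-matching objective. In the NumPy sketch below, each normalized activation x_0 is mixed with Gaussian noise at a random time t on the linear path x_t = (1 - t) x_0 + t ε, and the predicted velocity is regressed onto ε - x_0. The uniform distribution over t, the path convention, and the toy predictors are assumptions for illustration; the repo's `fm_prepare()` may sample t differently (for example, with a timestep shift).

```python
import numpy as np


def flow_matching_loss(predict_velocity, x0: np.ndarray, rng: np.random.Generator) -> float:
    """Monte Carlo estimate of E_{t,eps} ||v(x_t, t) - (eps - x0)||^2, averaged over dims."""
    eps = rng.standard_normal(x0.shape)
    t = rng.uniform(1e-3, 1.0, size=(x0.shape[0], 1))  # avoid t = 0 exactly
    x_t = (1.0 - t) * x0 + t * eps                      # point on the linear noise path
    target = eps - x0                                   # velocity d x_t / d t
    return float(((predict_velocity(x_t, t) - target) ** 2).mean())


# Toy "dataset" that is a single point a: the optimal velocity is (x_t - a) / t,
# which equals the target eps - a exactly, so its loss is ~0.
a = np.full(4, 2.0)
x0 = np.tile(a, (512, 1))


def optimal_velocity(x_t, t):
    return (x_t - a) / t


def zero_velocity(x_t, t):
    return np.zeros_like(x_t)


rng = np.random.default_rng(0)
print(flow_matching_loss(optimal_velocity, x0, rng), flow_matching_loss(zero_velocity, x0, rng))
```

A trained denoiser plays the role of `predict_velocity`; the gap between it and an uninformed predictor is what the reported training loss tracks.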

Related Models

Citation

```bibtex
@misc{steering-plms,
  title={Steering Protein Language Models},
  author={Zhang, Shuibai},
  year={2025},
  url={https://github.com/zhangshuibai/Steering-PLMs}
}
```