# TD3B / launch_multi_target.sh
# chq1155's picture
# Upload TD3B code (inference, training, baselines)
# ee6da62 verified
# raw
# history blame
# 6.56 kB
#!/bin/bash
# Multi-Target TD3B Training Launch Script
# Trains TD3B on multiple protein targets with random sampling strategy

# Fail fast: abort on command failure, unset variables, and pipeline errors.
set -euo pipefail

# ============================================================================
# Configuration
# ============================================================================
# Paths — update these to your local paths
BASE_PATH="/path/to/TD3B"
PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
TRAIN_CSV="${BASE_PATH}/data/train.csv"
VAL_CSV="${BASE_PATH}/data/test.csv" # Optional: create validation split
# Run configuration
RUN_NAME="multi_target_td3b" # Timestamp will be added automatically
DEVICE="cuda:0"
# Multi-target sampling
TARGETS_PER_MCTS=2 # Number of targets sampled per MCTS round (K)
RESAMPLE_TARGETS_EVERY=1 # Resample targets every N epochs
# Training hyperparameters
NUM_EPOCHS=200
LEARNING_RATE=3e-4
TRAIN_BATCH_SIZE=1 # Small batch size to prevent OOM
GRADIENT_ACCUMULATION_STEPS=32 # Effective batch size = 1 * 32 = 32
RESAMPLE_EVERY=10 # Run MCTS every N epochs
SAVE_EVERY=20
VALIDATE_EVERY=20
RESET_TREE_EVERY=50
# MCTS hyperparameters (aligned with v1, but can reduce for multi-target)
NUM_ITER=20 # MCTS iterations per resample (v1 default: 50, reduced for multi-target)
NUM_CHILDREN=16 # Children per MCTS expansion
BUFFER_SIZE=50 # Pareto buffer size (v1 default: 50)
REPLAY_BUFFER_SIZE=1000 # Recommended range: 500-5000 (0 disables replay)
REPLAY_BUFFER_STRATEGY="fifo" # fifo or random
ALPHA=0.1 # Temperature for importance weighting
EXPLORATION=1.0 # UCB exploration constant
# TD3B hyperparameters (aligned with v1 defaults)
CONTRASTIVE_WEIGHT=0.1 # v1 default: 0.1
CONTRASTIVE_MARGIN=1.0
KL_BETA=0.1 # v1 default: 0.1
MIN_AFFINITY_THRESHOLD=0.0 # CRITICAL: minimum affinity for allosteric control
SIGMOID_TEMPERATURE=0.1
# Validation
VAL_SAMPLES_PER_TARGET=20 # Number of sequences per target during validation
# Directional oracle (GPCR classifier)
ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
ORACLE_TR2D2_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
ORACLE_TOKENIZER_VOCAB="${BASE_PATH}/tokenizer/new_vocab.txt"
ORACLE_TOKENIZER_SPLITS="${BASE_PATH}/tokenizer/new_splits.txt"
ORACLE_ESM_NAME="facebook/esm2_t33_650M_UR50D"
ORACLE_ESM_CACHE_DIR="" # Optional: set to a cache dir path
ORACLE_ESM_LOCAL_FILES_ONLY=0 # Set to 1 to avoid network access
ORACLE_MAX_LIGAND_LENGTH=768
ORACLE_MAX_PROTEIN_LENGTH=1024
ORACLE_D_MODEL=256
ORACLE_N_HEADS=4
ORACLE_N_SELF_ATTN_LAYERS=1
ORACLE_N_BMCA_LAYERS=2
ORACLE_DROPOUT=0.3
# Optional oracle flags are accumulated as a space-separated flag string;
# downstream code word-splits this deliberately when building the command.
EXTRA_ORACLE_ARGS=""
if [[ -n "${ORACLE_ESM_CACHE_DIR}" ]]; then
  EXTRA_ORACLE_ARGS="${EXTRA_ORACLE_ARGS} --direction_oracle_esm_cache_dir ${ORACLE_ESM_CACHE_DIR}"
fi
if [[ "${ORACLE_ESM_LOCAL_FILES_ONLY}" -eq 1 ]]; then
  EXTRA_ORACLE_ARGS="${EXTRA_ORACLE_ARGS} --direction_oracle_esm_local_files_only"
fi
# W&B (optional)
WANDB_PROJECT="tr2d2-multi-target"
WANDB_ENTITY="phos_zj"
# ============================================================================
# Launch Training
# ============================================================================
# Abort immediately if the project directory is missing, instead of silently
# continuing from the wrong working directory (ShellCheck SC2164).
cd "${BASE_PATH}" || { echo "ERROR: cannot cd to ${BASE_PATH}" >&2; exit 1; }
echo "============================================================================"
echo "Multi-Target TD3B Training"
echo "============================================================================"
echo "Configuration:"
echo " - Targets per MCTS: ${TARGETS_PER_MCTS}"
echo " - Training batch size: ${TRAIN_BATCH_SIZE}"
echo " - Gradient accumulation: ${GRADIENT_ACCUMULATION_STEPS}"
echo " - Effective batch size: $((TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS))"
echo " - Epochs: ${NUM_EPOCHS}"
echo " - MCTS iterations: ${NUM_ITER}"
echo " - MCTS children: ${NUM_CHILDREN}"
echo " - Buffer size: ${BUFFER_SIZE}"
echo " - Replay buffer size: ${REPLAY_BUFFER_SIZE} (${REPLAY_BUFFER_STRATEGY})"
echo "============================================================================"
echo ""
# Build command
# The command is assembled as a bash array so every argument is passed intact
# (paths with spaces survive) and no `eval` is needed — `eval $CMD` re-parses
# the whole string through the shell, which is fragile and injection-prone.
CMD=(
  python finetune_multi_target.py
  --base_path "${BASE_PATH}"
  --train_csv "${TRAIN_CSV}"
  --pretrained_checkpoint "${PRETRAINED_CHECKPOINT}"
  --run_name "${RUN_NAME}"
  --device "${DEVICE}"

  --targets_per_mcts "${TARGETS_PER_MCTS}"
  --resample_targets_every "${RESAMPLE_TARGETS_EVERY}"

  --num_epochs "${NUM_EPOCHS}"
  --learning_rate "${LEARNING_RATE}"
  --train_batch_size "${TRAIN_BATCH_SIZE}"
  --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS}"
  --resample_every_n_step "${RESAMPLE_EVERY}"
  --save_every_n_epochs "${SAVE_EVERY}"
  --validate_every_n_epochs "${VALIDATE_EVERY}"
  --reset_every_n_step "${RESET_TREE_EVERY}"

  --num_iter "${NUM_ITER}"
  --num_children "${NUM_CHILDREN}"
  --buffer_size "${BUFFER_SIZE}"
  --replay_buffer_size "${REPLAY_BUFFER_SIZE}"
  --replay_buffer_strategy "${REPLAY_BUFFER_STRATEGY}"
  --alpha "${ALPHA}"
  --exploration "${EXPLORATION}"

  --contrastive_weight "${CONTRASTIVE_WEIGHT}"
  --contrastive_margin "${CONTRASTIVE_MARGIN}"
  --kl_beta "${KL_BETA}"
  --min_affinity_threshold "${MIN_AFFINITY_THRESHOLD}"
  --sigmoid_temperature "${SIGMOID_TEMPERATURE}"

  --direction_oracle_ckpt "${ORACLE_CKPT}"
  --direction_oracle_tr2d2_checkpoint "${ORACLE_TR2D2_CHECKPOINT}"
  --direction_oracle_tokenizer_vocab "${ORACLE_TOKENIZER_VOCAB}"
  --direction_oracle_tokenizer_splits "${ORACLE_TOKENIZER_SPLITS}"
  --direction_oracle_esm_name "${ORACLE_ESM_NAME}"
  --direction_oracle_max_ligand_length "${ORACLE_MAX_LIGAND_LENGTH}"
  --direction_oracle_max_protein_length "${ORACLE_MAX_PROTEIN_LENGTH}"
  --direction_oracle_d_model "${ORACLE_D_MODEL}"
  --direction_oracle_n_heads "${ORACLE_N_HEADS}"
  --direction_oracle_n_self_attn_layers "${ORACLE_N_SELF_ATTN_LAYERS}"
  --direction_oracle_n_bmca_layers "${ORACLE_N_BMCA_LAYERS}"
  --direction_oracle_dropout "${ORACLE_DROPOUT}"

  --val_samples_per_target "${VAL_SAMPLES_PER_TARGET}"

  --grad_clip
  --gradnorm_clip 1.0
  --wandb_project "${WANDB_PROJECT}"
)

# EXTRA_ORACLE_ARGS is a prebuilt space-separated flag string; word-split it
# intentionally (individual flags contain no whitespace). Empty → adds nothing.
# shellcheck disable=SC2206
CMD+=( ${EXTRA_ORACLE_ARGS} )

# Add validation CSV if it exists
if [[ -f "${VAL_CSV}" ]]; then
  CMD+=( --val_csv "${VAL_CSV}" )
  echo "Validation CSV: ${VAL_CSV}"
else
  echo "No validation CSV found (${VAL_CSV})"
  echo "Skipping validation during training"
fi

# Add W&B entity if specified
if [[ -n "${WANDB_ENTITY}" ]]; then
  CMD+=( --wandb_entity "${WANDB_ENTITY}" )
fi

echo ""
echo "Launching training..."
echo ""
# Execute the array directly — each element becomes exactly one argv entry.
"${CMD[@]}"