# Source: gRNAde/config.yaml (commit f145ec5, "Update config.yaml", last updated by chaitjo)
# Training configuration used for the checkpoint gRNAde_drop3d@0.75_maxlen@500.h5
# Misc configurations
device:
  # NOTE(review): 'gpu' is not one of the options listed in desc (cpu/cuda/xpu);
  # confirm how the consuming script interprets this value.
  value: 'gpu'
  desc: Device to run on (cpu/cuda/xpu)
gpu:
  value: 0
  desc: GPU ID
seed:
  value: 0
  desc: Random seed for reproducibility
save:
  value: true
  desc: Whether to save current and best model checkpoint
# Data configurations
data_path:
  value: "./data/"
  desc: Data directory (preprocessed and raw)
radius:
  value: 4.5
  desc: Radius for determining local neighborhoods in Angstrom (currently not used)
top_k:
  value: 32
  desc: Number of k-nearest neighbors in 3D and sequence space
num_rbf:
  value: 32
  desc: Number of radial basis functions to featurise distances
num_posenc:
  value: 32
  desc: Number of positional encodings to featurise edges
max_num_conformers:
  value: 1
  desc: Maximum number of conformations sampled per sequence
noise_scale:
  value: 0.1
  desc: Std of gaussian noise added to node coordinates during training
drop_prob_3d:
  # Corresponds to the 'drop3d@0.75' tag in the checkpoint name in the file header.
  value: 0.75
  desc: Dropout probability of 3D coordinates during training
random_order:
  value: true
  desc: Whether to train with random permutation or sequential order
max_nodes_batch:
  value: 3000
  desc: Maximum number of nodes in batch
max_nodes_sample:
  # Corresponds to the 'maxlen@500' tag in the checkpoint name in the file header.
  value: 500
  desc: Maximum number of nodes in batches with single samples (ie. maximum RNA length)
# Splitting configurations
split:
  value: 'das'
  desc: Type of data split (das/structsim_v2)
# Model configurations
# The *_dim pairs below are [scalar channels, vector channels].
model:
  value: 'gRNAde'
  desc: Model architecture
node_in_dim:
  value: [15, 4]  # (num_bb_atoms x 5, 2 + (num_bb_atoms - 1))
  desc: Input dimensions for node features (scalar channels, vector channels)
node_h_dim:
  value: [128, 16]
  desc: Hidden dimensions for node features (scalar channels, vector channels)
edge_in_dim:
  # NOTE(review): with num_bb_atoms = 3 (implied by node_in_dim's formula and by
  # the 3 vector channels here), num_rbf = 32 and num_posenc = 32, the formula
  # below gives 132 only for a non-integer num_edge_type -- verify it against
  # the edge featurizer before relying on it.
  value: [132, 3]  # (num_bb_atoms x num_edge_type + num_rbf + num_posenc, num_bb_atoms)
  desc: Input dimensions for edge features (scalar channels, vector channels)
edge_h_dim:
  value: [64, 4]
  desc: Hidden dimensions for edge features (scalar channels, vector channels)
num_layers:
  value: 4
  desc: Number of layers for encoder/decoder
drop_rate:
  value: 0.5
  desc: Dropout rate
out_dim:
  value: 4
  desc: Output dimension (4 bases for RNA)
# Training configurations
epochs:
  value: 100
  desc: Number of training epochs
lr:
  value: 0.0001
  desc: Learning rate
label_smoothing:
  value: 0.05
  desc: Label smoothing for cross entropy loss
batch_size:
  value: 8
  desc: Batch size for dataloaders (currently not used)
num_workers:
  value: 16
  desc: Number of workers for dataloaders
val_every:
  value: 10
  desc: Interval of training epochs after which validation is performed
# Evaluation configurations
model_path:
  value: ''
  desc: Path to model checkpoint for evaluation or reloading
evaluate:
  value: false
  desc: Whether to run evaluation (or training)
n_samples:
  value: 16
  desc: Number of samples for evaluating recovery
temperature:
  value: 0.1
  desc: Sampling temperature for evaluating recovery