File size: 8,978 Bytes
fd352d5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | """
Configuration file for Gemma 3 Gist Token Autoencoder
"""
import torch
# Model configuration
# MODEL_NAME = "google/gemma-3-1b-pt"
MODEL_NAME = "google/gemma-3-270m"
# MODEL_NAME = "google/gemma-3-270m-it"
# MODEL_NAME = "google/gemma-3-1b-it"
# Number of gist tokens (latent positions) placed between the input segment
# and the output segment.
NUM_GIST_TOKENS = 256
# Indices of the transformer layers used as gist layers, written as
# gist_layers = [i, j, k, ...]. Must contain at least one element and be
# strictly increasing (checked by _validate_gist_layers below).
# The 270M model has 18 layers and the 1B model 26; e.g. [5, 11, 17] gives
# multi-layer refinement.
# NOTE(review): unclear whether index 0 is a valid entry -- the commented-out
# list(range(0, 18)) alternative includes it; confirm against the consumer.
# GIST_LAYERS = list(range(0, 18))
GIST_LAYERS = [11]
# GIST_LAYERS = [10, 17]


def _validate_gist_layers(layers):
    """Raise ValueError unless *layers* is a non-empty, strictly increasing sequence.

    Enforces the GIST_LAYERS contract at import time so a malformed edit to
    this file fails immediately rather than somewhere downstream.
    """
    if not layers:
        raise ValueError("GIST_LAYERS must contain at least one layer index")
    if any(nxt <= prev for prev, nxt in zip(layers, layers[1:])):
        raise ValueError(f"GIST_LAYERS must be strictly increasing, got {layers!r}")


_validate_gist_layers(GIST_LAYERS)
# Training configuration
# Length, in token positions, of the input segment that gets encoded.
INPUT_SEQ_LENGTH = 1024
LEARNING_RATE = 5e-5
# LEARNING_RATE = 1e-4
# Micro-batch size per step; gradients are accumulated over
# GRADIENT_ACCUMULATION_STEPS micro-batches before each optimizer step.
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
# Build the dataset so every 100-character length bucket holds roughly the
# same number of examples (no oversampling).
STRATIFIED_SAMPLING = True
# Curriculum training: order the dataset by length, shortest first.
USE_CURRICULUM = True
NUM_EPOCHS = 2
ATTENTION_DROPOUT = 0.0
# Latent space augmentation (gist hidden states) during training.
# Gaussian noise with this std is added to the gist latent representations
# before reconstruction, to encourage robust representations. 0 = disable.
LATENT_NOISE_STD = 0.1
# L1 regularization coefficient on gist token activations: penalizes the mean
# absolute value of gist hidden states to encourage sparse latent
# representations. 0 = disable.
GIST_L1_COEFF = 0.0
# KL-divergence coefficient pulling gist activations toward N(0, I)
# (penalizes per-dimension deviation from a standard normal). 0 = disable.
GIST_KL_COEFF = 0.001
# KL-divergence coefficient against the frozen pre-trained reference model
# (RLHF-style): penalizes KL(current || reference) on the output-token
# distributions to prevent drift. The reference model runs a standard causal-LM
# forward pass on just the output tokens (no gist).
# 0 = no KL term in the loss; the reference model is then loaded only if
# TRACK_REF_KL is True.
REF_MODEL_KL_COEFF = 0.0001
# Track the reference KL even when REF_MODEL_KL_COEFF == 0: the reference model
# is loaded and the KL is monitored, but it is not added to the loss.
# Ignored when REF_MODEL_KL_COEFF > 0.
TRACK_REF_KL = True
# Probability of masking each gist position out of decoder attention during
# training (independently per position; attention set to -inf). 0 = disable.
LATENT_GIST_MASK_PROB = 0
# Probability of replacing each input (encoder) token with the mask token during
# training. Only content tokens in the input segment are affected; helps
# robustness. 0 = disable.
INPUT_MASK_PROB = 0
# Fraction of training examples whose sentences are randomly shuffled
# (sentence detection via PySBD). 0 = disable; 1.0 = shuffle every example.
# Any value > 0 disables the pre-tokenized cache.
SHUFFLE_SENTENCES_PROB = 0.0  # disabled
# Experiment: if True, only the first layer after the first gist layer reads
# gist activations; layers >= (gist_layer_i + 2) revert to isolated-output
# masking.
ISOLATE_ABOVE_GIST = False
# Fine-tune embed only: train only the token embeddings and freeze the rest.
FINE_TUNE_EMBED_ONLY = False
# Module to train in embed-only mode (usually model.embed_tokens).
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
# When True, only the gist token rows of the embedding table are trained
# (NUM_GIST_TOKENS x hidden_size params). When False, the full embedding table
# is trained (~168M params for gemma-3-270m).
EMBED_GIST_TOKENS_ONLY = True
# In embed-only mode, optionally initialize the gist token embeddings from
# GIST_INIT_PHRASE.
INIT_GIST_FROM_PHRASE = False
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"
# LoRA configuration (applies when USE_LORA=True and FINE_TUNE_EMBED_ONLY=False)
USE_LORA = True
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
# Comma-separated names of the projection modules that receive LoRA adapters.
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"
# Modules trained in full and saved with the adapter (presumably forwarded to
# PEFT as modules_to_save -- confirm in the training script).
LORA_MODULES_TO_SAVE = "model.embed_tokens"
# LORA_MODULES_TO_SAVE = ""
# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only the gist
# token rows are trained, via a gradient-mask hook (the same trick
# EMBED_GIST_TOKENS_ONLY uses for fine_tune_embed_only).
LORA_EMBED_GIST_TOKENS_ONLY = True
# When True, apply LoRA only to layers 0..gist_layer_i (encoder side); layers
# above the gist layer are frozen. This cuts the trainable LoRA parameters in
# proportion to the layers dropped -- about half only when the gist layer sits
# mid-stack (e.g. gist layer 11 on the 18-layer 270M model keeps LoRA on
# 12 of 18 layers).
LORA_ENCODER_ONLY = False
# When True, zero out the LoRA delta for all output token positions
# (positions >= input_seq_length + num_gist_tokens) by hooking each LoRA
# B-matrix. Input and gist positions still use LoRA; the output reconstruction
# uses only the frozen base weights.
LORA_OUTPUT_TOKENS_NO_LORA = False
# Layer-freeze mode: an alternative to both LoRA and fine_tune_embed_only.
# If set, every parameter is frozen first, then only the listed transformer
# layers are unfrozen, together with the gist token embeddings (through a
# gradient mask). None = disabled.
# Example: "10,11" trains layer 10 (which feeds the gist layer) and the gist
# layer itself (11).
TRAIN_LAYERS = None
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"
# ---------------------------------------------------------------------------
# Inline Gist Codec — joint LoRA + codec training
# ---------------------------------------------------------------------------
# Pre-trained gist checkpoint (PEFT adapter directory or hub repo ID) to
# warm-start from. Its adapter weights are loaded onto the base model, but
# training starts fresh: no optimizer/scheduler state from the prior run.
# Set to None to start from the raw base model.
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
# INIT_FROM_CHECKPOINT = None
# If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw
# base model_name, so KL drift is measured from the warm-start adapter
# (stage-1 checkpoint) rather than from the raw pretrained weights.
REF_MODEL_FROM_CHECKPOINT = False
# Codec enable flag despite the name: 0 = disabled, any value > 0 = enabled.
# The codec is additive (identity at init) and does no token/dim compression.
CODEC_K = 0
# Self-attention depth of the codec's encoder and decoder blocks.
CODEC_ENCODE_LAYERS = 1
CODEC_DECODE_LAYERS = 1
# Codec loss weights.
# MSE weight: 0 = rely on the CE loss alone (recommended for joint training).
CODEC_MSE_WEIGHT = 0.0
# KL weight: regularises the codec latent towards N(0, I).
CODEC_KL_WEIGHT = 0.001
# Codec learning rate (typically higher than the main LoRA LR since the codec
# starts from scratch).
CODEC_LR = 1e-4
# Per-dimension ReZero gate for the codec blocks.
#   True  -> alpha shape (D,): each feature dimension has its own gate
#            (more expressive).
#   False -> alpha shape (1,): a single scalar gate shared across all
#            dimensions (original behavior).
CODEC_PER_DIM_ALPHA = True
# Training mode.
#   False -> joint mode: LoRA adapters and the codec both train.
#   True  -> codec-only mode: the full LLM is frozen and only the codec trains.
CODEC_ONLY = False
# Dataset configuration
# DATASET_NAME may list several datasets separated by commas; MAX_SAMPLES caps
# the total across all of them.
# DATASET_NAME = "the_pile" # Will use a subset of The Pile
# DATASET_NAME = "c4_realnewslike" # C4 news domains - high quality; stratified gets short+long
# Current mix: stratified Pile (uncopyrighted) + stratified C4 realnewslike.
# NOTE(review): the 'stratigied' / 'straitified' spellings below are kept
# verbatim -- they must match the hub repo IDs; confirm before "fixing" them.
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike"
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
# DATASET_NAME = "bayes-group-diffusion/wikipedia" # Use DATASET_TEXT_COLUMN = "text_trg"
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike" # Private HF dataset (text column)
# DATASET_NAME = "dummy"
DATASET_SUBSET = None  # For c4_realnewslike, None uses "realnewslike" automatically
DATASET_SPLIT = "train"
MAX_SAMPLES = 1_400_000  # total cap across all datasets
# Name of the text column. Use "text_trg" for bayes-group-diffusion/wikipedia.
DATASET_TEXT_COLUMN = "text"
# DATASET_TEXT_COLUMN = "text_trg"
# Pile source filter (only for monology/pile-uncopyrighted).
# Pile records carry meta.pile_set_name: "Pile-CC", "Github", "StackExchange",
# "Wikipedia (en)", "ArXiv", etc.
# PILE_SOURCE_INCLUDE: if non-empty, keep only these sources
#   (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"]).
# PILE_SOURCE_EXCLUDE: if non-empty, skip these sources
#   (e.g. ["Pile-CC", "Github"] to reduce noise).
PILE_SOURCE_INCLUDE = []
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"]  # Exclude noisy Common Crawl and Github; set to [] to allow all
# Device configuration: the GPU when CUDA is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Paths
# Checkpoint and log output directories (currently the same directory).
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
# Alternative (s1) run directories:
# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# Pre-tokenized cache path. None = derive automatically from DATASET_NAME
# (cache/tokenized_{seq}_{gist}_{dataset}.pkl); "" = cache disabled.
TOKENIZED_CACHE_PATH = None
# Held-out perplexity evaluation set: a local JSON/JSONL path or an HF hub ID.
# None = perplexity evaluation disabled.
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
# Interval between perplexity evaluations (presumably in training steps -- confirm).
PERPLEXITY_EVAL_STEPS = 250
|