Upload config.py with huggingface_hub
Browse files
config.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for Gemma 3 Gist Token Autoencoder
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
# Model configuration
|
| 8 |
+
# MODEL_NAME = "google/gemma-3-1b-pt"
|
| 9 |
+
MODEL_NAME = "google/gemma-3-270m"
|
| 10 |
+
# MODEL_NAME = "google/gemma-3-270m-it"
|
| 11 |
+
# MODEL_NAME = "google/gemma-3-1b-it"
|
| 12 |
+
NUM_GIST_TOKENS = 256
|
| 13 |
+
# GIST_LAYERS = list(range(0, 18))
|
| 14 |
+
# gist_layers = [i, j, k, ...]; layer 0 is not in the flag. At least one element, strictly increasing.
|
| 15 |
+
GIST_LAYERS = [11]
|
| 16 |
+
# GIST_LAYERS = [10, 17] # 18 layers in 270M model, 26 in 1B model; e.g. [5, 11, 17] for multi refinement
|
| 17 |
+
|
| 18 |
+
# Training configuration
|
| 19 |
+
INPUT_SEQ_LENGTH = 1024 # length of input text to encode
|
| 20 |
+
LEARNING_RATE = 5e-5
|
| 21 |
+
# LEARNING_RATE = 1e-4
|
| 22 |
+
BATCH_SIZE = 4
|
| 23 |
+
GRADIENT_ACCUMULATION_STEPS = 4
|
| 24 |
+
STRATIFIED_SAMPLING = True # Collect data so each 100-char length bucket has ~equal count (no oversampling)
|
| 25 |
+
USE_CURRICULUM = False # Sort dataset by length (short→long) for curriculum training
|
| 26 |
+
NUM_EPOCHS = 2
|
| 27 |
+
|
| 28 |
+
ATTENTION_DROPOUT = 0.0
|
| 29 |
+
|
| 30 |
+
# Latent space augmentation (gist hidden states) during training.
|
| 31 |
+
# Gaussian noise with this std is added to gist latent representations before reconstruction.
|
| 32 |
+
# Helps the model learn robust representations. Set to 0 to disable.
|
| 33 |
+
LATENT_NOISE_STD = 0.1
|
| 34 |
+
|
| 35 |
+
# L1 regularization coefficient for gist token activations.
|
| 36 |
+
# Penalizes the mean absolute value of gist hidden states to encourage sparse latent representations.
|
| 37 |
+
# Set to 0 to disable.
|
| 38 |
+
GIST_L1_COEFF = 0.0
|
| 39 |
+
|
| 40 |
+
# KL divergence coefficient for gist activations toward N(0, I).
|
| 41 |
+
# Penalizes deviation of gist hidden states from standard normal (per dimension).
|
| 42 |
+
# Set to 0 to disable.
|
| 43 |
+
GIST_KL_COEFF = 0.001
|
| 44 |
+
|
| 45 |
+
# KL divergence coefficient against the frozen pre-trained reference model (RLHF-style).
|
| 46 |
+
# Penalizes KL(current || reference) on output token distributions to prevent drift.
|
| 47 |
+
# Reference model runs a standard causal LM forward on just the output tokens (no gist).
|
| 48 |
+
# Set to 0 to disable (reference model will not be loaded).
|
| 49 |
+
REF_MODEL_KL_COEFF = 0.0001
|
| 50 |
+
|
| 51 |
+
# Track ref KL divergence even when REF_MODEL_KL_COEFF=0 (loads ref model but does not add to loss).
|
| 52 |
+
# Useful for monitoring drift without penalizing it. Ignored when REF_MODEL_KL_COEFF > 0.
|
| 53 |
+
TRACK_REF_KL = True
|
| 54 |
+
|
| 55 |
+
# Randomly mask out this fraction of gist token positions from decoder attention during training.
|
| 56 |
+
# Each gist position is independently masked with this probability (attention set to -inf).
|
| 57 |
+
LATENT_GIST_MASK_PROB = 0
|
| 58 |
+
|
| 59 |
+
# Randomly replace this fraction of input (encoder) tokens with mask token during training.
|
| 60 |
+
# Only applies to content tokens in the input segment. Helps robustness. 0 = disable.
|
| 61 |
+
INPUT_MASK_PROB = 0
|
| 62 |
+
|
| 63 |
+
# Randomly shuffle sentences in the input during training. Fraction of examples to apply.
|
| 64 |
+
# Uses PySBD for robust sentence detection. 0 = disable. When > 0, pre-tokenized cache is disabled.
|
| 65 |
+
SHUFFLE_SENTENCES_PROB = 0.0 # just shuffle all the time
|
| 66 |
+
|
| 67 |
+
# Experiment: if True, only the first layer after the first gist layer reads gist activations.
|
| 68 |
+
# Layers >= (gist_layer_i + 2) revert to isolated-output masking.
|
| 69 |
+
ISOLATE_ABOVE_GIST = False
|
| 70 |
+
|
| 71 |
+
# Fine-tune embed only: train only token embeddings, freeze rest.
|
| 72 |
+
FINE_TUNE_EMBED_ONLY = False
|
| 73 |
+
# Module to train (usually model.embed_tokens)
|
| 74 |
+
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
|
| 75 |
+
# When True, only train gist token rows (~524K params). When False, train full embed table (~168M).
|
| 76 |
+
EMBED_GIST_TOKENS_ONLY = True
|
| 77 |
+
|
| 78 |
+
# Initialize gist token embeddings from phrase (when FINE_TUNE_EMBED_ONLY)
|
| 79 |
+
INIT_GIST_FROM_PHRASE = False
|
| 80 |
+
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"
|
| 81 |
+
|
| 82 |
+
# LoRA configuration (when USE_LORA=True, FINE_TUNE_EMBED_ONLY=False)
|
| 83 |
+
USE_LORA = True
|
| 84 |
+
LORA_R = 8
|
| 85 |
+
LORA_ALPHA = 16
|
| 86 |
+
LORA_DROPOUT = 0.05
|
| 87 |
+
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"
|
| 88 |
+
LORA_MODULES_TO_SAVE = "model.embed_tokens"
|
| 89 |
+
# LORA_MODULES_TO_SAVE = ""
|
| 90 |
+
# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only train gist token rows
|
| 91 |
+
# via a gradient mask hook (same trick as EMBED_GIST_TOKENS_ONLY for fine_tune_embed_only).
|
| 92 |
+
LORA_EMBED_GIST_TOKENS_ONLY = True
|
| 93 |
+
# When True, apply LoRA only to layers 0..gist_layer_i (encoder side). Layers above the gist
|
| 94 |
+
# layer are frozen, halving the number of trainable LoRA parameters.
|
| 95 |
+
LORA_ENCODER_ONLY = False
|
| 96 |
+
# When True, zero out the LoRA delta for all output token positions (positions >=
|
| 97 |
+
# input_seq_length + num_gist_tokens) by hooking each LoRA B-matrix. Input and gist
|
| 98 |
+
# positions still use LoRA; the output reconstruction uses only frozen base weights.
|
| 99 |
+
LORA_OUTPUT_TOKENS_NO_LORA = False
|
| 100 |
+
|
| 101 |
+
# Layer-freeze mode (alternative to LoRA and fine_tune_embed_only).
|
| 102 |
+
# When set, freezes all parameters then unfreezes only the listed transformer layers
|
| 103 |
+
# plus gist token embeddings (via gradient mask). Set to None to disable.
|
| 104 |
+
# E.g. "10,11" → unfreeze the layer feeding into gist (10) and the gist layer (11).
|
| 105 |
+
TRAIN_LAYERS = None
|
| 106 |
+
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Inline Gist Codec — joint LoRA + codec training
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# Pre-trained gist checkpoint to warm-start from (PEFT adapter directory).
|
| 112 |
+
# Adapter weights are loaded onto the base model; training starts fresh
|
| 113 |
+
# (no optimizer/scheduler state from the prior run).
|
| 114 |
+
# Set to None to start from the raw base model.
|
| 115 |
+
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
|
| 116 |
+
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
|
| 117 |
+
# INIT_FROM_CHECKPOINT = None
|
| 118 |
+
|
| 119 |
+
# If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw base model_name.
|
| 120 |
+
# Measures KL drift from the warm-start adapter (stage-1 checkpoint) rather than from raw pretrained weights.
|
| 121 |
+
REF_MODEL_FROM_CHECKPOINT = False
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Codec enable flag: 0 = disabled, any value > 0 = enabled.
|
| 125 |
+
# The codec is additive (identity at init); no token/dim compression.
|
| 126 |
+
CODEC_K = 0
|
| 127 |
+
|
| 128 |
+
# Self-attention depth for encoder and decoder blocks of the codec.
|
| 129 |
+
CODEC_ENCODE_LAYERS = 1
|
| 130 |
+
CODEC_DECODE_LAYERS = 1
|
| 131 |
+
|
| 132 |
+
# Codec loss weights.
|
| 133 |
+
# MSE weight: 0 = rely on CE loss alone (recommended for joint training).
|
| 134 |
+
CODEC_MSE_WEIGHT = 0.0
|
| 135 |
+
# KL weight: regularises compressed latent towards N(0, I).
|
| 136 |
+
CODEC_KL_WEIGHT = 0.001
|
| 137 |
+
|
| 138 |
+
# Codec learning rate (typically higher than main LoRA LR since it starts from scratch).
|
| 139 |
+
CODEC_LR = 1e-4
|
| 140 |
+
|
| 141 |
+
# Per-dimension ReZero gate for codec blocks.
|
| 142 |
+
# True = alpha shape (D,): each feature dimension has its own gate (more expressive).
|
| 143 |
+
# False = alpha shape (1,): single scalar gate shared across all dimensions (original behavior).
|
| 144 |
+
CODEC_PER_DIM_ALPHA = True
|
| 145 |
+
|
| 146 |
+
# False = joint mode: LoRA adapters + codec both train.
|
| 147 |
+
# True = codec-only mode: full LLM frozen, only codec trains.
|
| 148 |
+
CODEC_ONLY = False
|
| 149 |
+
|
| 150 |
+
# Dataset configuration
|
| 151 |
+
# DATASET_NAME = "the_pile" # Will use a subset of The Pile
|
| 152 |
+
# DATASET_NAME = "c4_realnewslike" # C4 news domains - high quality; stratified gets short+long
|
| 153 |
+
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike" # C4 news domains - high quality; stratified gets short+long
|
| 154 |
+
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
|
| 155 |
+
# DATASET_NAME = "bayes-group-diffusion/wikipedia" # Use DATASET_TEXT_COLUMN = "text_trg"
|
| 156 |
+
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike" # Private HF dataset (text column)
|
| 157 |
+
# DATASET_NAME = "dummy"
|
| 158 |
+
DATASET_SUBSET = None # For c4_realnewslike, None uses "realnewslike" automatically
|
| 159 |
+
DATASET_SPLIT = "train"
|
| 160 |
+
MAX_SAMPLES = 1_400_000 # total cap across all datasets
|
| 161 |
+
DATASET_TEXT_COLUMN = "text" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
|
| 162 |
+
# DATASET_TEXT_COLUMN = "text_trg" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
|
| 163 |
+
|
| 164 |
+
# Pile source filter (only for monology/pile-uncopyrighted).
|
| 165 |
+
# Pile has meta.pile_set_name: "Pile-CC", "Github", "StackExchange", "Wikipedia (en)", "ArXiv", etc.
|
| 166 |
+
# PILE_SOURCE_INCLUDE: if set, only keep these sources (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"])
|
| 167 |
+
# PILE_SOURCE_EXCLUDE: if set, skip these sources (e.g. ["Pile-CC", "Github"] to reduce noise)
|
| 168 |
+
PILE_SOURCE_INCLUDE = []
|
| 169 |
+
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"] # Exclude noisy Common Crawl; set to [] to allow all
|
| 170 |
+
|
| 171 |
+
# Device configuration
|
| 172 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 173 |
+
|
| 174 |
+
# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
|
| 175 |
+
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
|
| 176 |
+
# Paths
|
| 177 |
+
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
|
| 178 |
+
LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
|
| 179 |
+
# Pre-tokenized cache: None = auto-derive from DATASET_NAME (cache/tokenized_{seq}_{gist}_{dataset}.pkl). "" = disabled.
|
| 180 |
+
TOKENIZED_CACHE_PATH = None
|
| 181 |
+
|
| 182 |
+
# Perplexity eval: held-out set path (local JSON/JSONL or HF hub ID). None = disabled.
|
| 183 |
+
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
|
| 184 |
+
PERPLEXITY_EVAL_STEPS = 250
|