"""
Configuration file for Gemma 3 Gist Token Autoencoder.

Module-level constants only. Commented-out assignments are alternative
settings kept for quick switching between experiments.
"""
import torch

# Model configuration
# MODEL_NAME = "google/gemma-3-1b-pt"
MODEL_NAME = "google/gemma-3-270m"
# MODEL_NAME = "google/gemma-3-270m-it"
# MODEL_NAME = "google/gemma-3-1b-it"
NUM_GIST_TOKENS = 256
# GIST_LAYERS = list(range(0, 18))
# GIST_LAYERS: transformer layer indices [i, j, k, ...]. Must contain at least one
# element and be strictly increasing.
# NOTE(review): the original note read "layer 0 is not in the flag"; its meaning is
# unclear (the commented range(0, 18) alternative above does include 0) -- confirm
# whether layer 0 is allowed here.
GIST_LAYERS = [11]
# GIST_LAYERS = [10, 17]  # 18 layers in 270M model, 26 in 1B model; e.g. [5, 11, 17] for multi refinement
# Training configuration
# Length of the input (encoder) segment, in token positions.
INPUT_SEQ_LENGTH = 1024

# Optimisation schedule.
LEARNING_RATE = 5e-5  # alternative tried: 1e-4
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
NUM_EPOCHS = 2
ATTENTION_DROPOUT = 0.0

# Data ordering / sampling.
# Collect data so each 100-char length bucket has ~equal count (no oversampling).
STRATIFIED_SAMPLING = True
# Sort dataset by length (short -> long) for curriculum training.
USE_CURRICULUM = False
# Latent space augmentation (gist hidden states) during training.
# Gaussian noise with this std is added to gist latent representations before
# reconstruction. Helps the model learn robust representations. Set to 0 to disable.
LATENT_NOISE_STD = 0.1
# L1 regularization coefficient for gist token activations.
# Penalizes the mean absolute value of gist hidden states to encourage sparse
# latent representations. Set to 0 to disable.
GIST_L1_COEFF = 0.0
# KL divergence coefficient for gist activations toward N(0, I).
# Penalizes deviation of gist hidden states from a standard normal (per dimension).
# Set to 0 to disable.
GIST_KL_COEFF = 0.001
# KL divergence coefficient against the frozen pre-trained reference model (RLHF-style).
# Penalizes KL(current || reference) on output token distributions to prevent drift.
# The reference model runs a standard causal LM forward on just the output tokens (no gist).
# Set to 0 to remove this term from the loss. The reference model is then loaded only
# if TRACK_REF_KL is True.
REF_MODEL_KL_COEFF = 0.0001
# Track ref KL divergence even when REF_MODEL_KL_COEFF = 0 (loads the reference model
# but does not add the term to the loss). Useful for monitoring drift without
# penalizing it. Ignored when REF_MODEL_KL_COEFF > 0.
TRACK_REF_KL = True
# Randomly mask out gist token positions from decoder attention during training.
# Each gist position is independently masked with this probability (attention set
# to -inf), so this is the expected fraction of masked gist positions. 0 = disable.
LATENT_GIST_MASK_PROB = 0
# Randomly replace this fraction of input (encoder) tokens with the mask token during
# training. Only applies to content tokens in the input segment. Helps robustness.
# 0 = disable.
INPUT_MASK_PROB = 0
# Fraction of training examples whose input sentences are randomly shuffled.
# Uses PySBD for robust sentence detection. 0 = disable; 1.0 = shuffle every example.
# When > 0, the pre-tokenized cache is disabled.
SHUFFLE_SENTENCES_PROB = 0.0  # disabled (set to 1.0 to shuffle all the time)
# Experiment: if True, only the first layer after the first gist layer reads gist
# activations. Layers >= (gist_layer_i + 2) revert to isolated-output masking.
ISOLATE_ABOVE_GIST = False
# Fine-tune embed only: train only token embeddings, freeze everything else.
FINE_TUNE_EMBED_ONLY = False
# Module to train (usually model.embed_tokens).
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
# When True, only train the gist token rows of the embedding table
# (NUM_GIST_TOKENS x hidden_size params, e.g. 256 x 640 ~= 164K for gemma-3-270m).
# When False, train the full embedding table (~168M params for gemma-3-270m).
EMBED_GIST_TOKENS_ONLY = True
# Initialize gist token embeddings from GIST_INIT_PHRASE (when FINE_TUNE_EMBED_ONLY).
INIT_GIST_FROM_PHRASE = False
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"
# LoRA configuration (when USE_LORA=True, FINE_TUNE_EMBED_ONLY=False)
USE_LORA = True
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"  # comma-separated module names
LORA_MODULES_TO_SAVE = "model.embed_tokens"
# LORA_MODULES_TO_SAVE = ""
# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only train gist token rows
# via a gradient mask hook (same trick as EMBED_GIST_TOKENS_ONLY for fine_tune_embed_only).
LORA_EMBED_GIST_TOKENS_ONLY = True
# When True, apply LoRA only to layers 0..gist_layer_i (encoder side). Layers above the
# gist layer are frozen, so trainable LoRA parameters shrink in proportion to the layers
# kept -- e.g. 12 of 18 layers (2/3) for gist layer 11 on the 18-layer 270M model, not
# necessarily half.
LORA_ENCODER_ONLY = False
# When True, zero out the LoRA delta for all output token positions (positions >=
# input_seq_length + num_gist_tokens) by hooking each LoRA B-matrix. Input and gist
# positions still use LoRA; the output reconstruction uses only frozen base weights.
LORA_OUTPUT_TOKENS_NO_LORA = False
# Layer-freeze mode (alternative to LoRA and fine_tune_embed_only).
# When set, freezes all parameters, then unfreezes only the listed transformer layers
# plus the gist token embeddings (via gradient mask). Set to None to disable.
# E.g. "10,11" -> unfreeze the layer feeding into the gist layer (10) and the gist layer (11).
TRAIN_LAYERS = None
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"
# ---------------------------------------------------------------------------
# Inline Gist Codec — joint LoRA + codec training
# ---------------------------------------------------------------------------
# Pre-trained gist checkpoint to warm-start from (PEFT adapter directory / repo ID).
# Adapter weights are loaded onto the base model; training starts fresh
# (no optimizer/scheduler state from the prior run).
# Set to None to start from the raw base model.
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
# INIT_FROM_CHECKPOINT = None
# If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw base model_name.
# Measures KL drift from the warm-start adapter (stage-1 checkpoint) rather than from raw
# pretrained weights.
REF_MODEL_FROM_CHECKPOINT = False
# Codec enable flag: 0 = disabled, any value > 0 = enabled.
# The codec is additive (identity at init); no token/dim compression.
# NOTE(review): with CODEC_K = 0 the CODEC_* settings below are presumably unused -- confirm.
CODEC_K = 0
# Self-attention depth for the encoder and decoder blocks of the codec.
CODEC_ENCODE_LAYERS = 1
CODEC_DECODE_LAYERS = 1
# Codec loss weights.
# MSE weight: 0 = rely on CE loss alone (recommended for joint training).
CODEC_MSE_WEIGHT = 0.0
# KL weight: regularizes the codec latent towards N(0, I) (the codec does no
# token/dim compression, see CODEC_K).
CODEC_KL_WEIGHT = 0.001
# Codec learning rate (typically higher than the main LoRA LR since the codec starts from scratch).
CODEC_LR = 1e-4
# Per-dimension ReZero gate for codec blocks.
# True  = alpha shape (D,): each feature dimension has its own gate (more expressive).
# False = alpha shape (1,): single scalar gate shared across all dimensions (original behavior).
CODEC_PER_DIM_ALPHA = True
# False = joint mode: LoRA adapters + codec both train.
# True  = codec-only mode: full LLM frozen, only the codec trains.
CODEC_ONLY = False
# Dataset configuration
# DATASET_NAME = "the_pile"  # Will use a subset of The Pile
# DATASET_NAME = "c4_realnewslike"  # C4 news domains - high quality; stratified gets short+long
# Comma-separated dataset IDs: stratified Pile-uncopyrighted (1M) + stratified
# C4 realnewslike (400K). MAX_SAMPLES caps the total across all of them.
# NOTE(review): "stratigied" in the first ID looks like a typo, but it must match the
# Hub repo name exactly -- verify on the Hub before changing it.
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike"
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
# DATASET_NAME = "bayes-group-diffusion/wikipedia"  # Use DATASET_TEXT_COLUMN = "text_trg"
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike"  # Private HF dataset (text column)
# DATASET_NAME = "dummy"
DATASET_SUBSET = None  # For c4_realnewslike, None uses "realnewslike" automatically
DATASET_SPLIT = "train"
MAX_SAMPLES = 1_400_000  # total cap across all datasets
DATASET_TEXT_COLUMN = "text"  # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
# DATASET_TEXT_COLUMN = "text_trg"
# Pile source filter (only for monology/pile-uncopyrighted).
# Pile has meta.pile_set_name: "Pile-CC", "Github", "StackExchange", "Wikipedia (en)", "ArXiv", etc.
# PILE_SOURCE_INCLUDE: if non-empty, only keep these sources (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"])
# PILE_SOURCE_EXCLUDE: if non-empty, skip these sources (e.g. ["Pile-CC", "Github"] to reduce noise)
PILE_SOURCE_INCLUDE = []
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"]  # Exclude noisy Common Crawl and GitHub; set to [] to allow all
# Device configuration: prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Paths
# Earlier s1 run directories, kept for reference:
# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"

# Pre-tokenized cache location.
#   None -> auto-derive from DATASET_NAME (cache/tokenized_{seq}_{gist}_{dataset}.pkl)
#   ""   -> caching disabled
TOKENIZED_CACHE_PATH = None

# Perplexity evaluation on a held-out set (local JSON/JSONL or HF hub ID); None disables it.
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
PERPLEXITY_EVAL_STEPS = 250  # presumably: evaluate every N training steps