"""Configuration file for the Gemma 3 Gist Token Autoencoder.

Module-level constants covering the model, training, latent regularisation,
LoRA / embedding fine-tuning, the inline gist codec, datasets, device and paths.
"""
import torch

# Model configuration
# MODEL_NAME = "google/gemma-3-1b-pt"
MODEL_NAME = "google/gemma-3-270m"
# MODEL_NAME = "google/gemma-3-270m-it"
# MODEL_NAME = "google/gemma-3-1b-it"
NUM_GIST_TOKENS = 256
# GIST_LAYERS = list(range(0, 18))
# gist_layers = [i, j, k, ...]; layer 0 is not in the flag. At least one element, strictly increasing.
GIST_LAYERS = [11]
# GIST_LAYERS = [10, 17]
# 18 layers in 270M model, 26 in 1B model; e.g. [5, 11, 17] for multi refinement

# Training configuration
INPUT_SEQ_LENGTH = 1024  # length of input text to encode
LEARNING_RATE = 5e-5
# LEARNING_RATE = 1e-4
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
# Collect data so each 100-char length bucket has ~equal count (no oversampling).
STRATIFIED_SAMPLING = True
# Sort dataset by length (short→long) for curriculum training.
USE_CURRICULUM = True
NUM_EPOCHS = 2
ATTENTION_DROPOUT = 0.0

# Latent space augmentation (gist hidden states) during training.
# Gaussian noise with this std is added to gist latent representations before reconstruction.
# Helps the model learn robust representations. Set to 0 to disable.
LATENT_NOISE_STD = 0.1

# L1 regularization coefficient for gist token activations.
# Penalizes the mean absolute value of gist hidden states to encourage sparse latent representations.
# Set to 0 to disable.
GIST_L1_COEFF = 0.0

# KL divergence coefficient for gist activations toward N(0, I).
# Penalizes deviation of gist hidden states from standard normal (per dimension).
# Set to 0 to disable.
GIST_KL_COEFF = 0.001

# KL divergence coefficient against the frozen pre-trained reference model (RLHF-style).
# Penalizes KL(current || reference) on output token distributions to prevent drift.
# Reference model runs a standard causal LM forward on just the output tokens (no gist).
# Set to 0 to disable the penalty. With 0, the reference model is still loaded when
# TRACK_REF_KL is True (monitor-only); it is skipped only when TRACK_REF_KL is also False.
REF_MODEL_KL_COEFF = 0.0001

# Track ref KL divergence even when REF_MODEL_KL_COEFF=0 (loads ref model but does not add to loss).
# Useful for monitoring drift without penalizing it. Ignored when REF_MODEL_KL_COEFF > 0.
TRACK_REF_KL = True

# Randomly mask out this fraction of gist token positions from decoder attention during training.
# Each gist position is independently masked with this probability (attention set to -inf).
LATENT_GIST_MASK_PROB = 0

# Randomly replace this fraction of input (encoder) tokens with mask token during training.
# Only applies to content tokens in the input segment. Helps robustness. 0 = disable.
INPUT_MASK_PROB = 0

# Randomly shuffle sentences in the input during training. Fraction of examples to apply.
# Uses PySBD for robust sentence detection. 0 = disable. When > 0, pre-tokenized cache is disabled.
SHUFFLE_SENTENCES_PROB = 0.0  # 0.0 = never shuffle (disabled); 1.0 = shuffle every example

# Experiment: if True, only the first layer after the first gist layer reads gist activations.
# Layers >= (gist_layer_i + 2) revert to isolated-output masking.
ISOLATE_ABOVE_GIST = False

# Fine-tune embed only: train only token embeddings, freeze rest.
FINE_TUNE_EMBED_ONLY = False

# Module to train (usually model.embed_tokens)
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"

# When True, only train the gist token rows of the embedding table
# (NUM_GIST_TOKENS x hidden_size params; ~164K for 256 gist tokens, assuming
# gemma-3-270m's hidden size of 640 — verify for other models).
# When False, train the full embed table (~168M for gemma-3-270m).
EMBED_GIST_TOKENS_ONLY = True

# Initialize gist token embeddings from phrase (when FINE_TUNE_EMBED_ONLY)
INIT_GIST_FROM_PHRASE = False
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"

# LoRA configuration (when USE_LORA=True, FINE_TUNE_EMBED_ONLY=False)
USE_LORA = True
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"  # comma-separated module names
LORA_MODULES_TO_SAVE = "model.embed_tokens"
# LORA_MODULES_TO_SAVE = ""

# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only train gist token rows
# via a gradient mask hook (same trick as EMBED_GIST_TOKENS_ONLY for fine_tune_embed_only).
LORA_EMBED_GIST_TOKENS_ONLY = True

# When True, apply LoRA only to layers 0..gist_layer_i (encoder side). Layers above the gist
# layer are frozen, so trainable LoRA parameters shrink in proportion to the number of layers
# above it (e.g. GIST_LAYERS = [11] on the 18-layer 270M model keeps LoRA on 12 of 18 layers;
# the count is only halved when the gist layer sits mid-stack).
LORA_ENCODER_ONLY = False

# When True, zero out the LoRA delta for all output token positions (positions >=
# input_seq_length + num_gist_tokens) by hooking each LoRA B-matrix. Input and gist
# positions still use LoRA; the output reconstruction uses only frozen base weights.
LORA_OUTPUT_TOKENS_NO_LORA = False

# Layer-freeze mode (alternative to LoRA and fine_tune_embed_only).
# When set, freezes all parameters then unfreezes only the listed transformer layers
# plus gist token embeddings (via gradient mask). Set to None to disable.
# E.g. "10,11" → unfreeze the layer feeding into gist (10) and the gist layer (11).
TRAIN_LAYERS = None
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"

# ---------------------------------------------------------------------------
# Inline Gist Codec — joint LoRA + codec training
# ---------------------------------------------------------------------------

# Pre-trained gist checkpoint to warm-start from (PEFT adapter: a local directory
# or a Hugging Face Hub repo ID).
# Adapter weights are loaded onto the base model; training starts fresh
# (no optimizer/scheduler state from the prior run).
# Set to None to start from the raw base model.
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
# INIT_FROM_CHECKPOINT = None

# If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw base model_name.
# Measures KL drift from the warm-start adapter (stage-1 checkpoint) rather than from raw pretrained weights.
REF_MODEL_FROM_CHECKPOINT = False

# Codec enable flag: 0 = disabled, any value > 0 = enabled.
# The codec is additive (identity at init); no token/dim compression.
CODEC_K = 0

# Self-attention depth for encoder and decoder blocks of the codec.
CODEC_ENCODE_LAYERS = 1
CODEC_DECODE_LAYERS = 1

# Codec loss weights.
# MSE weight: 0 = rely on CE loss alone (recommended for joint training).
CODEC_MSE_WEIGHT = 0.0
# KL weight: regularises the codec's latent towards N(0, I). The latent is not
# dimensionally compressed (see the CODEC_K note above).
CODEC_KL_WEIGHT = 0.001

# Codec learning rate (typically higher than main LoRA LR since it starts from scratch).
CODEC_LR = 1e-4

# Per-dimension ReZero gate for codec blocks.
# True = alpha shape (D,): each feature dimension has its own gate (more expressive).
# False = alpha shape (1,): single scalar gate shared across all dimensions (original behavior).
CODEC_PER_DIM_ALPHA = True

# False = joint mode: LoRA adapters + codec both train.
# True = codec-only mode: full LLM frozen, only codec trains.
CODEC_ONLY = False

# Dataset configuration
# DATASET_NAME = "the_pile"  # Will use a subset of The Pile
# DATASET_NAME = "c4_realnewslike"  # C4 news domains - high quality; stratified gets short+long
# Multiple datasets are comma-separated: stratified Pile (uncopyrighted) + stratified C4 realnewslike.
# MAX_SAMPLES below caps the combined total.
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike"
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
# DATASET_NAME = "bayes-group-diffusion/wikipedia"  # Use DATASET_TEXT_COLUMN = "text_trg"
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike"  # Private HF dataset (text column)
# DATASET_NAME = "dummy"
DATASET_SUBSET = None  # For c4_realnewslike, None uses "realnewslike" automatically
DATASET_SPLIT = "train"
MAX_SAMPLES = 1_400_000  # total cap across all datasets
DATASET_TEXT_COLUMN = "text"  # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
# DATASET_TEXT_COLUMN = "text_trg"  # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia

# Pile source filter (only for monology/pile-uncopyrighted).
# NOTE(review): DATASET_NAME above points at arkanathp/...-pile-uncopyrighted, not
# monology/pile-uncopyrighted — confirm whether this filter still applies to it.
# Pile has meta.pile_set_name: "Pile-CC", "Github", "StackExchange", "Wikipedia (en)", "ArXiv", etc.
# PILE_SOURCE_INCLUDE: if set, only keep these sources (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"])
# PILE_SOURCE_EXCLUDE: if set, skip these sources (e.g. ["Pile-CC", "Github"] to reduce noise)
PILE_SOURCE_INCLUDE = []
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"]  # Exclude Common Crawl (Pile-CC) and Github; set to [] to allow all

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"

# Paths
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
# Logs are written alongside checkpoints; derived so the two cannot drift apart.
LOG_DIR = CHECKPOINT_DIR

# Pre-tokenized cache: None = auto-derive from DATASET_NAME (cache/tokenized_{seq}_{gist}_{dataset}.pkl). "" = disabled.
TOKENIZED_CACHE_PATH = None

# Perplexity eval: held-out set path (local JSON/JSONL or HF hub ID). None = disabled.
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
PERPLEXITY_EVAL_STEPS = 250