# arkanathp's picture
# Upload config.py with huggingface_hub
# fd352d5 verified
"""
Configuration file for Gemma 3 Gist Token Autoencoder
"""
import torch
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# Base model identifier. Exactly one MODEL_NAME line is active; the others are
# alternatives kept for quick switching.
# MODEL_NAME = "google/gemma-3-1b-pt"
MODEL_NAME = "google/gemma-3-270m"
# MODEL_NAME = "google/gemma-3-270m-it"
# MODEL_NAME = "google/gemma-3-1b-it"
# Number of gist tokens (the latent positions of the autoencoder).
NUM_GIST_TOKENS = 256
# Gist layer indices, gist_layers = [i, j, k, ...]. Must contain at least one
# element and be strictly increasing. The 270M model has 18 layers and the 1B
# model has 26; several indices (e.g. [5, 11, 17]) give multi refinement.
# NOTE(review): the original note reads "layer 0 is not in the flag" — presumably
# index 0 must not be listed; confirm against the training code.
# GIST_LAYERS = list(range(0, 18))
GIST_LAYERS = [11]
# GIST_LAYERS = [10, 17]
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# NOTE(review): presumably measured in tokens — confirm against the data pipeline.
INPUT_SEQ_LENGTH = 1024 # length of input text to encode
LEARNING_RATE = 5e-5
# LEARNING_RATE = 1e-4
# Assuming standard gradient accumulation, one optimizer step covers
# BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 16 examples — TODO confirm in the trainer.
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
STRATIFIED_SAMPLING = True # Collect data so each 100-char length bucket has ~equal count (no oversampling)
USE_CURRICULUM = True # Sort dataset by length (short→long) for curriculum training
NUM_EPOCHS = 2
# Presumably the attention dropout probability handed to the model; 0.0 = no dropout.
ATTENTION_DROPOUT = 0.0
# ---------------------------------------------------------------------------
# Regularisation of the gist latent and of the output distribution
# ---------------------------------------------------------------------------
# Latent space augmentation (gist hidden states) during training.
# Gaussian noise with this std is added to gist latent representations before reconstruction.
# Helps the model learn robust representations. Set to 0 to disable.
LATENT_NOISE_STD = 0.1
# L1 regularization coefficient for gist token activations.
# Penalizes the mean absolute value of gist hidden states to encourage sparse latent representations.
# Set to 0 to disable (currently disabled).
GIST_L1_COEFF = 0.0
# KL divergence coefficient for gist activations toward N(0, I).
# Penalizes deviation of gist hidden states from standard normal (per dimension).
# Set to 0 to disable.
GIST_KL_COEFF = 0.001
# KL divergence coefficient against the frozen pre-trained reference model (RLHF-style).
# Penalizes KL(current || reference) on output token distributions to prevent drift.
# Reference model runs a standard causal LM forward on just the output tokens (no gist).
# Set to 0 to disable (reference model will not be loaded).
REF_MODEL_KL_COEFF = 0.0001
# Track ref KL divergence even when REF_MODEL_KL_COEFF=0 (loads ref model but does not add to loss).
# Useful for monitoring drift without penalizing it. Ignored when REF_MODEL_KL_COEFF > 0,
# which is the case with the value set above.
TRACK_REF_KL = True
# ---------------------------------------------------------------------------
# Training-time augmentation probabilities
# ---------------------------------------------------------------------------
# All three values are probabilities in [0.0, 1.0] and are written as floats,
# matching the other float-valued knobs in this file.

# Randomly mask out this fraction of gist token positions from decoder attention during training.
# Each gist position is independently masked with this probability (attention set to -inf).
# 0.0 = no gist position is ever masked.
LATENT_GIST_MASK_PROB = 0.0
# Randomly replace this fraction of input (encoder) tokens with mask token during training.
# Only applies to content tokens in the input segment. Helps robustness. 0.0 = disable.
INPUT_MASK_PROB = 0.0
# Randomly shuffle sentences in the input during training. Fraction of examples to apply.
# Uses PySBD for robust sentence detection. 0.0 = disable. When > 0, pre-tokenized cache is disabled.
SHUFFLE_SENTENCES_PROB = 0.0  # currently disabled; use 1.0 to shuffle every example
# Experiment: if True, only the first layer after the first gist layer reads gist activations.
# Layers >= (gist_layer_i + 2) revert to isolated-output masking.
# NOTE(review): gist_layer_i presumably refers to the first entry of GIST_LAYERS — confirm.
# Currently off.
ISOLATE_ABOVE_GIST = False
# ---------------------------------------------------------------------------
# Embedding-only fine-tuning (alternative to LoRA)
# ---------------------------------------------------------------------------
# Fine-tune embed only: train only token embeddings, freeze rest.
FINE_TUNE_EMBED_ONLY = False
# Module to train (usually model.embed_tokens)
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
# When True, only train gist token rows (~524K params). When False, train full embed table (~168M).
# NOTE(review): the parameter counts are approximate and presumably depend on
# MODEL_NAME and NUM_GIST_TOKENS — re-check them if either changes.
EMBED_GIST_TOKENS_ONLY = True
# Initialize gist token embeddings from phrase (when FINE_TUNE_EMBED_ONLY)
INIT_GIST_FROM_PHRASE = False
# Phrase used for that initialization; only relevant when INIT_GIST_FROM_PHRASE is True.
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"
# ---------------------------------------------------------------------------
# LoRA configuration (when USE_LORA=True, FINE_TUNE_EMBED_ONLY=False)
# ---------------------------------------------------------------------------
USE_LORA = True
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
# Comma-separated module names (a single string, not a list).
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"
# Same comma-separated string format; the empty string variant is kept below.
LORA_MODULES_TO_SAVE = "model.embed_tokens"
# LORA_MODULES_TO_SAVE = ""
# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only train gist token rows
# via a gradient mask hook (same trick as EMBED_GIST_TOKENS_ONLY for fine_tune_embed_only).
LORA_EMBED_GIST_TOKENS_ONLY = True
# When True, apply LoRA only to layers 0..gist_layer_i (encoder side). Layers above the gist
# layer are frozen, halving the number of trainable LoRA parameters.
# NOTE(review): "halving" presumably assumes a gist layer near the middle of the stack;
# the actual saving depends on GIST_LAYERS and the model depth.
LORA_ENCODER_ONLY = False
# When True, zero out the LoRA delta for all output token positions (positions >=
# input_seq_length + num_gist_tokens) by hooking each LoRA B-matrix. Input and gist
# positions still use LoRA; the output reconstruction uses only frozen base weights.
LORA_OUTPUT_TOKENS_NO_LORA = False
# Layer-freeze mode (alternative to LoRA and fine_tune_embed_only).
# When set, freezes all parameters then unfreezes only the listed transformer layers
# plus gist token embeddings (via gradient mask). Set to None to disable.
# The value is a comma-separated string of layer indices.
# E.g. "10,11" → unfreeze the layer feeding into gist (10) and the gist layer (11).
TRAIN_LAYERS = None
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"
# ---------------------------------------------------------------------------
# Inline Gist Codec — joint LoRA + codec training
# ---------------------------------------------------------------------------
# Pre-trained gist checkpoint to warm-start from (PEFT adapter directory).
# Adapter weights are loaded onto the base model; training starts fresh
# (no optimizer/scheduler state from the prior run).
# Set to None to start from the raw base model.
# NOTE(review): the active value looks like a hub-style "user/repo" ID rather than a
# local directory — presumably both are accepted; confirm in the loading code.
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
# INIT_FROM_CHECKPOINT = None
# If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw base model_name.
# Measures KL drift from the warm-start adapter (stage-1 checkpoint) rather than from raw pretrained weights.
REF_MODEL_FROM_CHECKPOINT = False
# Codec enable flag: 0 = disabled, any value > 0 = enabled.
# The codec is additive (identity at init); no token/dim compression.
# Currently 0, i.e. the codec is disabled.
CODEC_K = 0
# Self-attention depth for encoder and decoder blocks of the codec.
CODEC_ENCODE_LAYERS = 1
CODEC_DECODE_LAYERS = 1
# Codec loss weights.
# MSE weight: 0 = rely on CE loss alone (recommended for joint training).
CODEC_MSE_WEIGHT = 0.0
# KL weight: regularises compressed latent towards N(0, I).
CODEC_KL_WEIGHT = 0.001
# Codec learning rate (typically higher than main LoRA LR since it starts from scratch).
CODEC_LR = 1e-4
# Per-dimension ReZero gate for codec blocks.
# True = alpha shape (D,): each feature dimension has its own gate (more expressive).
# False = alpha shape (1,): single scalar gate shared across all dimensions (original behavior).
CODEC_PER_DIM_ALPHA = True
# False = joint mode: LoRA adapters + codec both train.
# True = codec-only mode: full LLM frozen, only codec trains.
CODEC_ONLY = False
# ---------------------------------------------------------------------------
# Dataset configuration
# ---------------------------------------------------------------------------
# DATASET_NAME = "the_pile" # Will use a subset of The Pile
# DATASET_NAME = "c4_realnewslike" # C4 news domains - high quality; stratified gets short+long
# The active value holds two dataset IDs in one comma-separated string.
# NOTE(review): presumably the loader splits on "," — confirm. The spelling "stratigied"
# is part of the ID string; do not correct it here without checking the actual repo name.
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike" # comma-separated: Pile (uncopyrighted) subset + C4 realnewslike subset
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
# DATASET_NAME = "bayes-group-diffusion/wikipedia" # Use DATASET_TEXT_COLUMN = "text_trg"
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike" # Private HF dataset (text column)
# DATASET_NAME = "dummy"
DATASET_SUBSET = None # For c4_realnewslike, None uses "realnewslike" automatically
DATASET_SPLIT = "train"
MAX_SAMPLES = 1_400_000 # total cap across all datasets
DATASET_TEXT_COLUMN = "text" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
# DATASET_TEXT_COLUMN = "text_trg" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
# Pile source filter (only for monology/pile-uncopyrighted).
# Pile has meta.pile_set_name: "Pile-CC", "Github", "StackExchange", "Wikipedia (en)", "ArXiv", etc.
# PILE_SOURCE_INCLUDE: if set, only keep these sources (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"])
# PILE_SOURCE_EXCLUDE: if set, skip these sources (e.g. ["Pile-CC", "Github"] to reduce noise)
# NOTE(review): an empty include list presumably means "no include filter" — confirm.
PILE_SOURCE_INCLUDE = []
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"] # Exclude Pile-CC (noisy Common Crawl) and Github; set to [] to allow all
# Device configuration
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Checkpoints and logs currently share one directory. The commented-out pair
# below is an alternative run directory kept for reference.
# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
# Pre-tokenized cache: None = auto-derive from DATASET_NAME (cache/tokenized_{seq}_{gist}_{dataset}.pkl). "" = disabled.
TOKENIZED_CACHE_PATH = None
# Perplexity eval: held-out set path (local JSON/JSONL or HF hub ID). None = disabled.
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
# NOTE(review): presumably the number of training steps between perplexity evaluations — confirm.
PERPLEXITY_EVAL_STEPS = 250