arkanathp committed on
Commit
fd352d5
·
verified ·
1 Parent(s): 77ceaa2

Upload config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. config.py +184 -0
config.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file for Gemma 3 Gist Token Autoencoder
3
+ """
4
+
5
+ import torch
6
+
7
+ # Model configuration
8
+ # MODEL_NAME = "google/gemma-3-1b-pt"
9
+ MODEL_NAME = "google/gemma-3-270m"
10
+ # MODEL_NAME = "google/gemma-3-270m-it"
11
+ # MODEL_NAME = "google/gemma-3-1b-it"
12
+ NUM_GIST_TOKENS = 256
13
+ # GIST_LAYERS = list(range(0, 18))
14
+ # gist_layers = [i, j, k, ...]; layer 0 is not in the flag. At least one element, strictly increasing.
15
+ GIST_LAYERS = [11]
16
+ # GIST_LAYERS = [10, 17] # 18 layers in 270M model, 26 in 1B model; e.g. [5, 11, 17] for multi refinement
17
+
18
+ # Training configuration
19
+ INPUT_SEQ_LENGTH = 1024 # length of input text to encode
20
+ LEARNING_RATE = 5e-5
21
+ # LEARNING_RATE = 1e-4
22
+ BATCH_SIZE = 4
23
+ GRADIENT_ACCUMULATION_STEPS = 4
24
+ STRATIFIED_SAMPLING = True # Collect data so each 100-char length bucket has ~equal count (no oversampling)
25
+ USE_CURRICULUM = True # Sort dataset by length (short→long) for curriculum training
26
+ NUM_EPOCHS = 2
27
+
28
+ ATTENTION_DROPOUT = 0.0
29
+
30
+ # Latent space augmentation (gist hidden states) during training.
31
+ # Gaussian noise with this std is added to gist latent representations before reconstruction.
32
+ # Helps the model learn robust representations. Set to 0 to disable.
33
+ LATENT_NOISE_STD = 0.1
34
+
35
+ # L1 regularization coefficient for gist token activations.
36
+ # Penalizes the mean absolute value of gist hidden states to encourage sparse latent representations.
37
+ # Set to 0 to disable.
38
+ GIST_L1_COEFF = 0.0
39
+
40
+ # KL divergence coefficient for gist activations toward N(0, I).
41
+ # Penalizes deviation of gist hidden states from standard normal (per dimension).
42
+ # Set to 0 to disable.
43
+ GIST_KL_COEFF = 0.001
44
+
45
+ # KL divergence coefficient against the frozen pre-trained reference model (RLHF-style).
46
+ # Penalizes KL(current || reference) on output token distributions to prevent drift.
47
+ # Reference model runs a standard causal LM forward on just the output tokens (no gist).
48
+ # Set to 0 to disable (reference model will not be loaded).
49
+ REF_MODEL_KL_COEFF = 0.0001
50
+
51
+ # Track ref KL divergence even when REF_MODEL_KL_COEFF=0 (loads ref model but does not add to loss).
52
+ # Useful for monitoring drift without penalizing it. Ignored when REF_MODEL_KL_COEFF > 0.
53
+ TRACK_REF_KL = True
54
+
55
+ # Randomly mask out this fraction of gist token positions from decoder attention during training.
56
+ # Each gist position is independently masked with this probability (attention set to -inf). 0 = disable.
57
+ LATENT_GIST_MASK_PROB = 0
58
+
59
+ # Randomly replace this fraction of input (encoder) tokens with mask token during training.
60
+ # Only applies to content tokens in the input segment. Helps robustness. 0 = disable.
61
+ INPUT_MASK_PROB = 0
62
+
63
+ # Randomly shuffle sentences in the input during training. Fraction of examples to apply.
64
+ # Uses PySBD for robust sentence detection. 0 = disable. When > 0, pre-tokenized cache is disabled.
65
+ SHUFFLE_SENTENCES_PROB = 0.0 # 0.0 = sentence shuffling disabled; set to 1.0 to shuffle every example
66
+
67
+ # Experiment: if True, only the first layer after the first gist layer reads gist activations.
68
+ # Layers >= (gist_layer_i + 2) revert to isolated-output masking.
69
+ ISOLATE_ABOVE_GIST = False
70
+
71
+ # Fine-tune embed only: train only token embeddings, freeze rest.
72
+ FINE_TUNE_EMBED_ONLY = False
73
+ # Module to train (usually model.embed_tokens)
74
+ EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
75
+ # When True, only train gist token rows (~524K params). When False, train full embed table (~168M).
76
+ EMBED_GIST_TOKENS_ONLY = True
77
+
78
+ # Initialize gist token embeddings from phrase (when FINE_TUNE_EMBED_ONLY)
79
+ INIT_GIST_FROM_PHRASE = False
80
+ GIST_INIT_PHRASE = "remember what I've seen and repeat after me"
81
+
82
+ # LoRA configuration (when USE_LORA=True, FINE_TUNE_EMBED_ONLY=False)
83
+ USE_LORA = True
84
+ LORA_R = 8
85
+ LORA_ALPHA = 16
86
+ LORA_DROPOUT = 0.05
87
+ LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"
88
+ LORA_MODULES_TO_SAVE = "model.embed_tokens"
89
+ # LORA_MODULES_TO_SAVE = ""
90
+ # When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, only train gist token rows
91
+ # via a gradient mask hook (same trick as EMBED_GIST_TOKENS_ONLY for fine_tune_embed_only).
92
+ LORA_EMBED_GIST_TOKENS_ONLY = True
93
+ # When True, apply LoRA only to layers 0..gist_layer_i (encoder side). Layers above the gist
94
+ # layer are frozen, halving the number of trainable LoRA parameters.
95
+ LORA_ENCODER_ONLY = False
96
+ # When True, zero out the LoRA delta for all output token positions (positions >=
97
+ # input_seq_length + num_gist_tokens) by hooking each LoRA B-matrix. Input and gist
98
+ # positions still use LoRA; the output reconstruction uses only frozen base weights.
99
+ LORA_OUTPUT_TOKENS_NO_LORA = False
100
+
101
+ # Layer-freeze mode (alternative to LoRA and fine_tune_embed_only).
102
+ # When set, freezes all parameters then unfreezes only the listed transformer layers
103
+ # plus gist token embeddings (via gradient mask). Set to None to disable.
104
+ # E.g. "10,11" → unfreeze the layer feeding into gist (10) and the gist layer (11).
105
+ TRAIN_LAYERS = None
106
+ # TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Inline Gist Codec — joint LoRA + codec training
110
+ # ---------------------------------------------------------------------------
111
+ # Pre-trained gist checkpoint to warm-start from (PEFT adapter directory).
112
+ # Adapter weights are loaded onto the base model; training starts fresh
113
+ # (no optimizer/scheduler state from the prior run).
114
+ # Set to None to start from the raw base model.
115
+ INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"
116
+ # INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
117
+ # INIT_FROM_CHECKPOINT = None
118
+
119
+ # If True, load INIT_FROM_CHECKPOINT as the reference model instead of the raw base model_name.
120
+ # Measures KL drift from the warm-start adapter (stage-1 checkpoint) rather than from raw pretrained weights.
121
+ REF_MODEL_FROM_CHECKPOINT = False
122
+
123
+
124
+ # Codec enable flag: 0 = disabled, any value > 0 = enabled.
125
+ # The codec is additive (identity at init); no token/dim compression.
126
+ CODEC_K = 0
127
+
128
+ # Self-attention depth for encoder and decoder blocks of the codec.
129
+ CODEC_ENCODE_LAYERS = 1
130
+ CODEC_DECODE_LAYERS = 1
131
+
132
+ # Codec loss weights.
133
+ # MSE weight: 0 = rely on CE loss alone (recommended for joint training).
134
+ CODEC_MSE_WEIGHT = 0.0
135
+ # KL weight: regularises compressed latent towards N(0, I).
136
+ CODEC_KL_WEIGHT = 0.001
137
+
138
+ # Codec learning rate (typically higher than main LoRA LR since it starts from scratch).
139
+ CODEC_LR = 1e-4
140
+
141
+ # Per-dimension ReZero gate for codec blocks.
142
+ # True = alpha shape (D,): each feature dimension has its own gate (more expressive).
143
+ # False = alpha shape (1,): single scalar gate shared across all dimensions (original behavior).
144
+ CODEC_PER_DIM_ALPHA = True
145
+
146
+ # False = joint mode: LoRA adapters + codec both train.
147
+ # True = codec-only mode: full LLM frozen, only codec trains.
148
+ CODEC_ONLY = False
149
+
150
+ # Dataset configuration
151
+ # DATASET_NAME = "the_pile" # Will use a subset of The Pile
152
+ # DATASET_NAME = "c4_realnewslike" # C4 news domains - high quality; stratified gets short+long
153
+ DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike" # Comma-separated HF dataset IDs: stratified Pile (uncopyrighted) + stratified C4 realnewslike
154
+ # DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
155
+ # DATASET_NAME = "bayes-group-diffusion/wikipedia" # Use DATASET_TEXT_COLUMN = "text_trg"
156
+ # DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike" # Private HF dataset (text column)
157
+ # DATASET_NAME = "dummy"
158
+ DATASET_SUBSET = None # For c4_realnewslike, None uses "realnewslike" automatically
159
+ DATASET_SPLIT = "train"
160
+ MAX_SAMPLES = 1_400_000 # total cap across all datasets
161
+ DATASET_TEXT_COLUMN = "text" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
162
+ # DATASET_TEXT_COLUMN = "text_trg" # Column name for text. Use "text_trg" for bayes-group-diffusion/wikipedia
163
+
164
+ # Pile source filter (only for monology/pile-uncopyrighted).
165
+ # Pile has meta.pile_set_name: "Pile-CC", "Github", "StackExchange", "Wikipedia (en)", "ArXiv", etc.
166
+ # PILE_SOURCE_INCLUDE: if set, only keep these sources (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"])
167
+ # PILE_SOURCE_EXCLUDE: if set, skip these sources (e.g. ["Pile-CC", "Github"] to reduce noise)
168
+ PILE_SOURCE_INCLUDE = []
169
+ PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"] # Exclude noisy Common Crawl and GitHub sources; set to [] to allow all
170
+
171
+ # Device configuration
172
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
173
+
174
+ # CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
175
+ # LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
176
+ # Paths
177
+ CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
178
+ LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
179
+ # Pre-tokenized cache: None = auto-derive from DATASET_NAME (cache/tokenized_{seq}_{gist}_{dataset}.pkl). "" = disabled.
180
+ TOKENIZED_CACHE_PATH = None
181
+
182
+ # Perplexity eval: held-out set path (local JSON/JSONL or HF hub ID). None = disabled.
183
+ PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
184
+ PERPLEXITY_EVAL_STEPS = 250