File size: 8,978 Bytes
fd352d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
Configuration module for the Gemma 3 gist-token autoencoder.

Every setting is a plain module-level constant; edit the values below to
change the experiment.
"""

import torch

# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# Base model identifier. Alternatives:
# MODEL_NAME = "google/gemma-3-1b-pt"
MODEL_NAME = "google/gemma-3-270m"
# MODEL_NAME = "google/gemma-3-270m-it"
# MODEL_NAME = "google/gemma-3-1b-it"

# Number of gist tokens whose hidden states form the latent representation.
NUM_GIST_TOKENS = 256

# Gist layer index/indices. Must be a non-empty, strictly increasing list,
# e.g. [5, 11, 17] for multi-layer refinement.
# Layer counts: 18 in the 270M model, 26 in the 1B model.
# Layer 0 is not meant to appear in this list.
# NOTE(review): the commented-out `list(range(0, 18))` alternative below does
# include 0, which contradicts the previous line — confirm which is accepted.
# GIST_LAYERS = list(range(0, 18))
GIST_LAYERS = [11]
# GIST_LAYERS = [10, 17]

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
NUM_EPOCHS = 2
# Length of the input text to encode (presumably in tokens — TODO confirm).
INPUT_SEQ_LENGTH = 1024

# Optimisation.
LEARNING_RATE = 5e-5
# LEARNING_RATE = 1e-4
BATCH_SIZE = 4
# Presumably micro-batches accumulated per optimizer step, giving an effective
# batch of BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS examples — verify in the trainer.
GRADIENT_ACCUMULATION_STEPS = 4

# Data ordering.
# Collect data so every 100-character length bucket holds roughly the same
# number of examples (no oversampling).
STRATIFIED_SAMPLING = True
# Sort the dataset by length, shortest first, for curriculum training.
USE_CURRICULUM = True

# Attention dropout probability (0.0 = no dropout).
ATTENTION_DROPOUT = 0.0

# Latent-space augmentation during training: Gaussian noise with this standard
# deviation is added to the gist hidden states (the latent representation)
# before reconstruction, to help the model learn robust representations.
# Set to 0 to disable.
LATENT_NOISE_STD = 0.1

# L1 regularization coefficient for gist token activations: penalizes the mean
# absolute value of the gist hidden states to encourage sparse latents.
# Set to 0 to disable.
GIST_L1_COEFF = 0.0

# KL-divergence coefficient pulling gist activations toward N(0, I): penalizes
# per-dimension deviation of the gist hidden states from a standard normal.
# Set to 0 to disable.
GIST_KL_COEFF = 0.001

# KL-divergence coefficient against the frozen pre-trained reference model
# (RLHF-style): penalizes KL(current || reference) on the output-token
# distributions to prevent drift. The reference model runs a standard
# causal-LM forward pass on the output tokens only (no gist tokens).
# Set to 0 to remove the term from the loss. In that case the reference model
# is still loaded when TRACK_REF_KL is True (for monitoring only); it is
# skipped entirely only when TRACK_REF_KL is also False.
REF_MODEL_KL_COEFF = 0.0001

# Track the reference-model KL even when REF_MODEL_KL_COEFF == 0: the
# reference model is loaded, but the KL is not added to the loss. Useful for
# monitoring drift without penalizing it. Ignored when REF_MODEL_KL_COEFF > 0.
TRACK_REF_KL = True

# Randomly hide gist token positions from the decoder's attention during
# training: each gist position is independently masked (attention set to -inf)
# with this probability. 0 disables it.
LATENT_GIST_MASK_PROB = 0

# Randomly replace this fraction of input (encoder) tokens with the mask token
# during training, for robustness. Only content tokens in the input segment are
# affected. 0 disables it.
INPUT_MASK_PROB = 0

# Fraction of training examples whose sentences are randomly shuffled; sentence
# boundaries come from PySBD. 0.0 disables shuffling entirely, and 1.0 shuffles
# every example. Any value > 0 also disables the pre-tokenized cache.
SHUFFLE_SENTENCES_PROB = 0.0

# Experiment: if True, only the first layer after the first gist layer reads
# the gist activations; layers >= (gist_layer_i + 2) revert to isolated-output
# masking.
ISOLATE_ABOVE_GIST = False

# Embedding-only fine-tuning: train only the token embeddings and freeze
# everything else.
FINE_TUNE_EMBED_ONLY = False
# Module to train in that mode (usually model.embed_tokens).
EMBED_MODULE_TO_TRAIN = "model.embed_tokens"
# True: train only the gist-token rows of the embedding table, i.e.
#   NUM_GIST_TOKENS x hidden_size parameters — about 256 x 640 ≈ 164K for
#   gemma-3-270m (assuming hidden_size 640; verify against the model config).
# False: train the full embedding table (~168M parameters for gemma-3-270m).
EMBED_GIST_TOKENS_ONLY = True

# Initialize the gist-token embeddings from GIST_INIT_PHRASE
# (only used when FINE_TUNE_EMBED_ONLY is True).
INIT_GIST_FROM_PHRASE = False
GIST_INIT_PHRASE = "remember what I've seen and repeat after me"

# LoRA configuration (used when USE_LORA=True and FINE_TUNE_EMBED_ONLY=False).
USE_LORA = True
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
# Comma-separated names of the modules that receive LoRA adapters.
LORA_TARGET_MODULES = "q_proj,v_proj,k_proj,o_proj"
# Modules trained and saved alongside the adapter (presumably comma-separated,
# like LORA_TARGET_MODULES).
LORA_MODULES_TO_SAVE = "model.embed_tokens"
# LORA_MODULES_TO_SAVE = ""
# When True and model.embed_tokens is in LORA_MODULES_TO_SAVE, train only the
# gist-token rows via a gradient-mask hook (the same trick EMBED_GIST_TOKENS_ONLY
# uses in fine_tune_embed_only mode).
LORA_EMBED_GIST_TOKENS_ONLY = True
# When True, apply LoRA only to layers 0..gist_layer_i (the encoder side);
# layers above the gist layer are frozen. Trainable LoRA parameters shrink to
# (gist_layer_i + 1) / num_layers of the full count — e.g. 12/18 (two thirds)
# for gist layer 11 of the 18-layer 270M model, so not necessarily half.
LORA_ENCODER_ONLY = False
# When True, zero out the LoRA delta at every output-token position
# (positions >= input_seq_length + num_gist_tokens) by hooking each LoRA
# B matrix. Input and gist positions still use LoRA; the output reconstruction
# uses only the frozen base weights.
LORA_OUTPUT_TOKENS_NO_LORA = False

# Layer-freeze mode — an alternative to both LoRA and fine_tune_embed_only.
# When not None, every parameter is frozen first, then only the transformer
# layers named in this comma-separated string are unfrozen, together with the
# gist-token embeddings (trained through a gradient mask). None turns it off.
# Example: "10,11" unfreezes the layer feeding into the gist layer (10) and the
# gist layer itself (11).
TRAIN_LAYERS = None
# TRAIN_LAYERS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13"

# ---------------------------------------------------------------------------
# Inline Gist Codec — joint LoRA + codec training
# ---------------------------------------------------------------------------
# Warm start: a previously trained gist checkpoint (PEFT adapter directory)
# whose adapter weights are loaded onto the base model. Training itself starts
# fresh — no optimizer or scheduler state carries over from the earlier run.
# None starts from the raw base model instead. Other candidates:
# INIT_FROM_CHECKPOINT = "arkanathp/s2_noise_0_1_latent_kl_0_01_ref_kl_0_01"
# INIT_FROM_CHECKPOINT = None
INIT_FROM_CHECKPOINT = "arkanathp/s1_curriculum_true_latent_0_1_with_gist_kl_0_0001_ref_kl_0_0001"

# Source of the reference model:
#   True  -> load INIT_FROM_CHECKPOINT, so KL drift is measured from the
#            warm-start adapter (the stage-1 checkpoint);
#   False -> use the raw pretrained base model (MODEL_NAME).
REF_MODEL_FROM_CHECKPOINT = False


# Codec enable flag: 0 = disabled, any value > 0 = enabled.
# The codec is additive (identity at initialization) and performs no token or
# dimension compression.
CODEC_K = 0

# Number of self-attention layers in the codec's encoder and decoder blocks.
CODEC_ENCODE_LAYERS = 1
CODEC_DECODE_LAYERS = 1

# Codec loss weights.
# MSE weight: 0 relies on the CE loss alone (recommended for joint training).
CODEC_MSE_WEIGHT = 0.0
# KL weight: regularizes the codec latent toward N(0, I). The latent keeps
# the gist shape (no compression, see CODEC_K above).
CODEC_KL_WEIGHT = 0.001

# Codec learning rate — typically higher than the main LoRA learning rate,
# since the codec starts from scratch.
CODEC_LR = 1e-4

# Per-dimension ReZero gate for the codec blocks.
#   True  -> alpha has shape (D,): each feature dimension has its own gate
#            (more expressive).
#   False -> alpha has shape (1,): one scalar gate shared across all
#            dimensions (original behavior).
CODEC_PER_DIM_ALPHA = True

# Training mode.
#   False -> joint mode: the LoRA adapters and the codec both train.
#   True  -> codec-only mode: the full LLM is frozen and only the codec trains.
CODEC_ONLY = False

# Dataset configuration
# Several datasets can be combined as a comma-separated list of dataset IDs
# (as in the current value); MAX_SAMPLES caps the total across all of them.
# Alternatives:
# DATASET_NAME = "the_pile"  # Will use a subset of The Pile
# DATASET_NAME = "c4_realnewslike"  # C4 news domains - high quality; stratified gets short+long
# Current: stratified Pile (uncopyrighted) + stratified C4 realnewslike.
# NOTE(review): "stratigied" in the first ID looks like a typo of "stratified";
# it is left as-is because it must match the actual hub repository name — verify.
DATASET_NAME = "arkanathp/1M-stratigied-pile-uncopyrighted,arkanathp/400K-stratified-c4-realnewslike"
# DATASET_NAME = "arkanathp/400K-straitified-pile-uncopyrighted"
# DATASET_NAME = "bayes-group-diffusion/wikipedia"  # Use DATASET_TEXT_COLUMN = "text_trg"
# DATASET_NAME = "arkanathp/400K-stratified-c4-realnewslike"  # Private HF dataset (text column)
# DATASET_NAME = "dummy"
DATASET_SUBSET = None  # For c4_realnewslike, None uses "realnewslike" automatically
DATASET_SPLIT = "train"
MAX_SAMPLES = 1_400_000  # total cap across all datasets
DATASET_TEXT_COLUMN = "text"  # Name of the text column; use "text_trg" for bayes-group-diffusion/wikipedia
# DATASET_TEXT_COLUMN = "text_trg"

# Pile source filter (only for monology/pile-uncopyrighted).
# NOTE(review): the current DATASET_NAME uses arkanathp Pile/C4 copies, not
# monology/pile-uncopyrighted — confirm whether this filter applies to them.
# Each Pile record carries meta.pile_set_name, e.g. "Pile-CC", "Github",
# "StackExchange", "Wikipedia (en)", "ArXiv".
# PILE_SOURCE_INCLUDE: if set (non-empty), keep only these sources
#   (e.g. ["Wikipedia (en)", "StackExchange", "ArXiv"]).
# PILE_SOURCE_EXCLUDE: if set (non-empty), skip these sources
#   (e.g. ["Pile-CC", "Github"] to reduce noise).
PILE_SOURCE_INCLUDE = []
# Excludes Pile-CC (noisy Common Crawl) and Github (source code); set to [] to allow all.
PILE_SOURCE_EXCLUDE = ["Pile-CC", "Github"]

# Device configuration: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Paths
# Output directories for this run (checkpoints and logs currently share one directory).
CHECKPOINT_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
LOG_DIR = "./checkpoints_lora_s2_ablation/s2_no_codec"
# Alternative (stage-1 run):
# CHECKPOINT_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"
# LOG_DIR = "./checkpoints_lora_s1/s1_curriculum_true_finetune_all_selected_layers_kl_0_0001_ref_kl_0_0001"

# Pre-tokenized cache:
#   None -> path auto-derived from DATASET_NAME
#           (cache/tokenized_{seq}_{gist}_{dataset}.pkl)
#   ""   -> caching disabled
TOKENIZED_CACHE_PATH = None

# Perplexity evaluation.
# Held-out set: a local JSON/JSONL path or an HF hub ID; None disables the eval.
PERPLEXITY_TEST_SET_PATH = "arkanathp/800_stratified_heldout_c4_realnewslike"
# Evaluation interval (presumably in training steps — TODO confirm).
PERPLEXITY_EVAL_STEPS = 250