# lucky_pick_scheduler
# Community Article — published April 15, 2026
import modal
import os
# Modal app + CUDA training image: torch pinned to the cu121 wheel index so it
# matches the base image's CUDA 12.1, Unsloth installed from git.
# NOTE(review): the app name is an empty placeholder — fill in before deploy.
app = modal.App("")
image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.1.1-devel-ubuntu22.04",
        add_python="3.11",
    )
    .apt_install("git", "gcc", "g++")
    # Torch stack first, from the CUDA-12.1 wheel index.
    .pip_install(
        "torch==2.5.1",
        "torchvision==0.20.1",
        "torchaudio==2.5.1",
        index_url="https://download.pytorch.org/whl/cu121",
    )
    # Unsloth from source (colab-new extra), then its runtime dependencies.
    .pip_install("unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git")
    .pip_install(
        "unsloth_zoo",
        "trl",
        "peft",
        "accelerate",
        "bitsandbytes",
        "datasets",
        "wandb",
        "protobuf",
    )
    # torchao gets pulled in transitively and is removed here — presumably it
    # conflicts with this unsloth/torch pin; TODO confirm.
    .run_commands("pip uninstall -y torchao")
)
# Persistent volume for checkpoints/outputs, mounted at /vol in train().
volume = modal.Volume.from_name("bella-v6-deep-chaos-v3", create_if_missing=True)
# Training configuration. Empty strings are deployment-specific placeholders
# (model/dataset/repo names, W&B project) — fill them in before launching.
# Several values were left blank in the draft (syntax errors); they are filled
# with conservative defaults below, each marked TODO.
CONFIG = dict(
    MODEL_NAME="",  # HF model id to fine-tune (placeholder)
    MAX_SEQ_LENGTH=2048,  # TODO: was blank — 2048 is a safe default context length
    LOAD_IN_4BIT=True,  # TODO: was blank — default to 4-bit loading
    MODEL_DTYPE="bfloat16",  # one of float16 / bfloat16 / float32 (see resolve_dtype)
    DATASET_NAME="",  # HF dataset id (placeholder)
    DATASET_SPLIT="train",
    # --- entropy injector ---
    INJECTOR_ENABLED=True,  # TODO: was blank — injectors are inert while NOISE_SCALE is 0.0
    INJECTOR_LAYERS=[6, 12, 18],
    INJECTOR_HIDDEN_DIM=128,
    INJECTOR_NOISE_SCALE=0.0,
    # --- auxiliary loss weights (all disabled at 0.0) ---
    DECORRELATION_WEIGHT=0.0,
    PROBE_ALIGNMENT_WEIGHT=0.0,
    PROBE_ALIGNMENT_EVERY=1,
    ORTHO_PEAK_LAYERS=[4, 12, 21],
    ORTHO_LOSS_WEIGHT=0.0,
    ORTHO_ON_PRE_INJECTION=True,
    ORTHO_METRIC="token_wise",
    STABILITY_LOSS_WEIGHT=0.0,
    STABILITY_ON_POST_INJECTION=True,
    PRESERVATION_KL_WEIGHT=0.0,
    # --- layer freezing ---
    LAYER_FREEZE_FRACTION=0.0,
    TRAIN_LAYER_RANGE=(0, 27),
    # --- deep chaos structured dropout ---
    DEEP_CHAOS_ENABLED=True,
    DEEP_CHAOS_SACRED_LAYERS=[0, 1, 26, 27],  # never sliced
    DEEP_CHAOS_VICTIM_RANGE=(2, 26),  # half-open: layers 2..25 are victims
    DEEP_CHAOS_MIN_LAYER_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_LAYER_SURVIVAL=0.70,
    DEEP_CHAOS_MIN_HEAD_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_HEAD_SURVIVAL=0.70,
    DEEP_CHAOS_CHANNEL_GROUP_SIZE=128,
    DEEP_CHAOS_MIN_CHANNEL_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_CHANNEL_SURVIVAL=0.70,
    DEEP_CHAOS_MLP_GATE_GROUP_SIZE=128,
    DEEP_CHAOS_MIN_MLP_GATE_SURVIVAL=0.35,
    DEEP_CHAOS_MAX_MLP_GATE_SURVIVAL=0.80,
    DEEP_CHAOS_HIDDEN_GROUP_SIZE=64,
    DEEP_CHAOS_MIN_HIDDEN_SURVIVAL=0.60,
    DEEP_CHAOS_MAX_HIDDEN_SURVIVAL=0.95,
    DEEP_CHAOS_MAX_CONSECUTIVE_ON=5,
    DEEP_CHAOS_MAX_CONSECUTIVE_OFF=10,
    DEEP_CHAOS_STICKY_INTERVAL=50,  # re-shuffle topology every N steps
    DEEP_CHAOS_SEED=42,
    # --- optimizer / schedule ---
    NUM_EPOCHS=1,  # TODO: was blank — default to a single epoch
    BATCH_SIZE=2,
    GRADIENT_ACCUMULATION=2,  # 2x2 effective batch empirically worked best for the author
    LEARNING_RATE=2e-5,
    LR_SCHEDULER="cosine",
    WARMUP_RATIO=0.15,
    WEIGHT_DECAY=0.01,
    MAX_GRAD_NORM=1.0,
    LOGGING_STEPS=10,
    SAVE_STEPS=2000,
    PROBE_STEPS=100,
    # --- output / tracking (placeholders) ---
    OUTPUT_DIR="",
    HF_REPO="",
    WANDB_PROJECT="",
    WANDB_RUN_NAME="",
    PROBE_ENABLED=True,
    PROBE_LAYER_RANGE=(2, 26),
)
@app.function(
image=image,
gpu="",
timeout=6 * 3600,
volumes={"/vol": volume},
secrets=[
modal.Secret.from_name(""),
modal.Secret.from_name(""),
],
)
def train(config: dict = CONFIG):
import copy
import json
import math
import random
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from datasets import load_dataset
from huggingface_hub import HfApi
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt
from transformers import DataCollatorForSeq2Seq, Trainer, TrainerCallback, TrainingArguments
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Expose config keys as attributes (C.MODEL_NAME, C.BATCH_SIZE, ...) by
# building a throwaway class whose class dict is the config mapping.
C = type("C", (), config)()
def resolve_dtype(name):
    """Map a config dtype string to the corresponding torch dtype.

    Args:
        name: one of "float16", "bfloat16", "float32".

    Raises:
        ValueError: if *name* is not a supported dtype string (the draft
            raised a bare KeyError, which is cryptic when MODEL_DTYPE is
            left as an empty placeholder).
    """
    dtypes = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    try:
        return dtypes[name]
    except KeyError:
        raise ValueError(
            f"Unsupported MODEL_DTYPE {name!r}; expected one of {sorted(dtypes)}"
        ) from None
def resolve_transformer_layers(target_model):
    """Locate the decoder layer list across plain and wrapper model layouts.

    Probes, in order: model.layers, base_model.model.layers,
    base_model.model.model.layers (the usual PEFT nesting).
    """
    candidates = [getattr(target_model, "model", None)]
    base = getattr(target_model, "base_model", None)
    if base is not None:
        inner = getattr(base, "model", None)
        candidates.append(inner)
        if inner is not None:
            candidates.append(getattr(inner, "model", None))
    for node in candidates:
        if node is not None and hasattr(node, "layers"):
            return node.layers
    raise AttributeError("Could not find transformer layers")
def resolve_transformer_root(target_model):
    """Return the module that OWNS the decoder layer list (see resolve_transformer_layers)."""
    candidates = [getattr(target_model, "model", None)]
    base = getattr(target_model, "base_model", None)
    if base is not None:
        inner = getattr(base, "model", None)
        candidates.append(inner)
        if inner is not None:
            candidates.append(getattr(inner, "model", None))
    for node in candidates:
        if node is not None and hasattr(node, "layers"):
            return node
    raise AttributeError("Could not find transformer root")
# Log the runtime environment, then load the base model + tokenizer via Unsloth.
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"PyTorch: {torch.__version__} CUDA: {torch.version.cuda}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=C.MODEL_NAME,
    max_seq_length=C.MAX_SEQ_LENGTH,
    dtype=resolve_dtype(C.MODEL_DTYPE),
    load_in_4bit=C.LOAD_IN_4BIT,
)
# Lets gradients flow into input embeddings — presumably needed because
# embeddings are frozen below and/or gradient checkpointing is used; TODO confirm.
model.enable_input_require_grads()
model_device = next(model.parameters()).device
print(f"Model loaded: {C.MODEL_NAME} VRAM: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
def setup_layer_freezing(target_model):
    """Freeze embeddings, lm_head, and all decoder layers outside C.TRAIN_LAYER_RANGE.

    Returns a dict with the total layer count and the clamped trainable range.
    """
    layers = resolve_transformer_layers(target_model)
    total = len(layers)
    lo, hi = C.TRAIN_LAYER_RANGE if C.TRAIN_LAYER_RANGE else (0, total - 1)
    # Clamp the requested range into [0, total-1] and keep it non-empty.
    lo = max(0, min(lo, total - 1))
    hi = max(lo, min(hi, total - 1))
    counts = {True: 0, False: 0}
    for name, param in target_model.named_parameters():
        keep = True
        if "embed" in name or "lm_head" in name:
            keep = False
        elif "layers." in name:
            # Parameter names look like "...layers.<idx>....": pull the layer index.
            idx = int(name.split("layers.")[1].split(".")[0])
            keep = lo <= idx <= hi
        param.requires_grad = keep
        counts[keep] += 1
    print(f"Freeze: train layers {lo}-{hi}/{total - 1} frozen={counts[False]} trainable={counts[True]}")
    return {"total_layers": total, "train_start": lo, "train_end": hi}
freezing_info = setup_layer_freezing(model)
# Frozen deep-copy of the freshly loaded model, used as the reference for the
# KL preservation loss — only materialized when that loss is enabled.
base_model_ref = None
if C.PRESERVATION_KL_WEIGHT > 0:
    base_model_ref = copy.deepcopy(model)
    base_model_ref.eval()
    for param in base_model_ref.parameters():
        param.requires_grad = False
class EntropyInjector(nn.Module):
    """Learned residual perturbation for one decoder layer.

    A two-layer bottleneck MLP produces a normalized perturbation that is
    scaled by ``noise_scale``, clamped to [-1, 1], optionally restricted to
    assistant tokens, and added to the hidden states. The down-projection is
    zero-initialized, so the module starts as an exact identity.
    """

    def __init__(self, hidden_dim, injector_hidden_dim, noise_scale):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, injector_hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(injector_hidden_dim, hidden_dim, bias=False),
        )
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.noise_scale = noise_scale
        nn.init.orthogonal_(self.net[0].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.net[2].weight)

    def forward(self, hidden_states, assistant_mask=None):
        orig_dtype = hidden_states.dtype
        delta = self.norm(self.net(hidden_states)) * self.noise_scale
        # Scrub NaNs and bound the perturbation magnitude per element.
        delta = torch.nan_to_num(delta, nan=0.0).clamp(-1, 1)
        if assistant_mask is not None:
            # Perturb assistant tokens only.
            delta = delta * assistant_mask.to(orig_dtype).unsqueeze(-1)
        return torch.nan_to_num(hidden_states + delta, nan=0.0).to(orig_dtype)
class InjectorManager:
    """Owns one EntropyInjector per target layer, plus the forward hooks that
    apply them and the auxiliary losses (ortho / stability / decorrelation)
    computed from the cached pre-/post-injection hidden states.

    Hooks fire only while the model is in training mode (or when
    ``enabled_for_probe`` is set); caches must be cleared each step via
    ``clear_cache`` by the caller.
    """

    def __init__(self, target_model, layers, hidden_dim, injector_hidden_dim, noise_scale, ortho_weight, ortho_on_pre, ortho_metric, device):
        self.model = target_model
        self.injector_layers = layers
        self.ortho_loss_weight = ortho_weight
        self.ortho_on_pre = ortho_on_pre
        self.ortho_metric = ortho_metric
        self.injectors = nn.ModuleDict()
        self.hooks = []
        # layer_idx -> hidden states captured before / after injection.
        self.cached_pre = {}
        self.cached_post = {}
        self.current_assistant_mask = None
        self.enabled_for_probe = False
        for layer_idx in layers:
            # Match the base model's device/dtype so the addition is well-typed.
            inj = EntropyInjector(hidden_dim, injector_hidden_dim, noise_scale).to(
                device=device,
                dtype=next(target_model.parameters()).dtype,
            )
            self.injectors[str(layer_idx)] = inj
        self._install()

    def _install(self):
        """Register a forward hook on each configured decoder layer."""
        layers = resolve_transformer_layers(self.model)
        for layer_idx in self.injector_layers:
            self.hooks.append(layers[layer_idx].register_forward_hook(self._make_hook(layer_idx)))

    def _make_hook(self, layer_idx):
        """Build the hook closure: cache pre/post states and inject the perturbation.

        Falls back to the unperturbed hidden states if the injector output is
        non-finite. Handles both tuple and tensor layer outputs.
        """
        injector = self.injectors[str(layer_idx)]
        def hook_fn(module, inputs, output):
            # Inactive during eval unless a probe explicitly enables it.
            if not self.model.training and not self.enabled_for_probe:
                return output
            if isinstance(output, tuple):
                hidden = output[0]
                self.cached_pre[layer_idx] = hidden
                result = injector(hidden, self.current_assistant_mask)
                if not torch.isfinite(result).all():
                    result = hidden
                self.cached_post[layer_idx] = result
                return (result,) + output[1:]
            self.cached_pre[layer_idx] = output
            result = injector(output, self.current_assistant_mask)
            if not torch.isfinite(result).all():
                result = output
            self.cached_post[layer_idx] = result
            return result
        return hook_fn

    def set_noise_scale(self, noise_scale):
        """Set the perturbation scale on every injector (e.g. for scheduling/probing)."""
        for injector in self.injectors.values():
            injector.noise_scale = noise_scale

    def compute_ortho_loss(self, use_post=False):
        """Penalize pairwise cosine similarity between cached layer activations.

        Returns a scalar scaled by ``ortho_loss_weight`` and capped at 10.0;
        zero if fewer than two layers are cached.
        """
        cache = self.cached_post if use_post else self.cached_pre
        if len(cache) < 2:
            return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
        loss = torch.zeros(1, device=next(self.injectors.parameters()).device)
        layers = sorted(cache.keys())
        for i in range(len(layers)):
            for j in range(i + 1, len(layers)):
                a = cache[layers[i]]
                b = cache[layers[j]]
                if self.ortho_metric == "token_wise":
                    # Per-token cosine similarity, squared and averaged.
                    sim = F.cosine_similarity(a.view(-1, a.size(-1)), b.view(-1, b.size(-1)), dim=-1)
                    loss = loss + sim.square().mean()
                else:
                    # Sequence-pooled similarity per layer pair.
                    sim = F.cosine_similarity(a.mean((0, 1)).unsqueeze(0), b.mean((0, 1)).unsqueeze(0)).squeeze()
                    loss = loss + sim.square()
        return torch.nan_to_num(loss * self.ortho_loss_weight).clamp_max(10.0).squeeze()

    def snapshot_cache(self, use_post=False, detach=False):
        """Shallow-copy a cache; optionally detach tensors from the graph."""
        cache = self.cached_post if use_post else self.cached_pre
        return {k: v.detach() for k, v in cache.items()} if detach else dict(cache)

    def compute_stability_loss(self, clean, perturbed, mask=None):
        """MSE between clean and perturbed activations on their common layers.

        Non-finite layer pairs are skipped; with a mask, the MSE is averaged
        over masked tokens only. Returns zero if nothing is comparable.
        """
        if not clean or not perturbed:
            return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
        losses = []
        for layer_idx in sorted(set(clean) & set(perturbed)):
            a = clean[layer_idx].to(perturbed[layer_idx].dtype)
            b = perturbed[layer_idx]
            if not torch.isfinite(a).all() or not torch.isfinite(b).all():
                continue
            diff = (b - a).pow(2)
            if mask is not None:
                m = mask.to(diff.dtype).unsqueeze(-1)
                denom = m.sum() * diff.size(-1)
                if denom.item() > 0:
                    losses.append((diff * m).sum() / denom)
                continue
            losses.append(diff.mean())
        return torch.stack(losses).mean() if losses else torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()

    def compute_decorrelation_loss(self, mask=None, use_post=True):
        """Squared cosine similarity between sequence-pooled activations of layer pairs."""
        cache = self.cached_post if use_post else self.cached_pre
        layers = sorted(cache.keys())
        if len(layers) < 2:
            return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
        pooled = {}
        for layer_idx in layers:
            hidden = cache[layer_idx]
            if not torch.isfinite(hidden).all():
                continue
            hidden_f = hidden.float()
            if mask is not None:
                # Masked mean over the sequence dimension.
                m = mask.to(hidden_f.dtype).unsqueeze(-1)
                pooled[layer_idx] = (hidden_f * m).sum(1) / m.sum(1).clamp_min(1)
            else:
                pooled[layer_idx] = hidden_f.mean(1)
        valid = sorted(pooled.keys())
        if len(valid) < 2:
            return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
        terms = []
        for i in range(len(valid)):
            for j in range(i + 1, len(valid)):
                terms.append(F.cosine_similarity(pooled[valid[i]], pooled[valid[j]], dim=-1).square().mean())
        return torch.stack(terms).mean() if terms else torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()

    def clear_cache(self):
        """Drop cached activations and the current assistant mask (call every step)."""
        self.cached_pre.clear()
        self.cached_post.clear()
        self.current_assistant_mask = None

    def get_trainable_params(self):
        """Injector parameters to hand to the optimizer."""
        return list(self.injectors.parameters())

    def remove_hooks(self):
        """Detach all forward hooks (restores untouched model behavior)."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
# Build the injector stack only when both the flag and the layer list are set;
# downstream code treats injector_mgr is None as "feature disabled".
injector_mgr = None
if C.INJECTOR_ENABLED and C.INJECTOR_LAYERS:
    injector_mgr = InjectorManager(
        model,
        C.INJECTOR_LAYERS,
        model.config.hidden_size,
        C.INJECTOR_HIDDEN_DIM,
        C.INJECTOR_NOISE_SCALE,
        C.ORTHO_LOSS_WEIGHT,
        C.ORTHO_ON_PRE_INJECTION,
        C.ORTHO_METRIC,
        model_device,
    )
class SlicedLinear:
    """Apply an ``nn.Linear`` restricted to surviving output rows / input columns.

    With ``compact=True`` the result keeps only the surviving output features;
    otherwise the partial result is scattered back into a zero tensor of the
    full output width.
    """

    @staticmethod
    def forward(linear, x, alive_out=None, alive_in=None, compact=False):
        weight, bias = linear.weight, linear.bias
        # Fast path: nothing is sliced.
        if alive_out is None and alive_in is None:
            return F.linear(x, weight, bias)
        w, b, inp = weight, bias, x
        if alive_out is not None:
            w = w[alive_out]
            b = b[alive_out] if b is not None else None
        if alive_in is not None:
            w = w[:, alive_in]
            inp = x[..., alive_in]
        partial = F.linear(inp, w, b)
        if compact or alive_out is None:
            return partial
        full_out = weight.shape[0]
        if alive_out.shape[0] < full_out:
            # Scatter surviving outputs into a full-width zero tensor.
            scattered = x.new_zeros(*x.shape[:-1], full_out)
            scattered[..., alive_out] = partial
            return scattered
        return partial
@dataclass
class LayerTopology:
    """Per-layer survival plan for one deep-chaos shuffle block."""

    # One of "both" | "attn" | "mlp" | "identity" | "dead".
    mode: str = "both"
    alive_q_heads: list | None = None  # surviving query-head indices
    alive_kv_heads: list | None = None  # surviving key/value-head indices
    alive_q_out: torch.Tensor | None = None  # flat q_proj output rows to keep
    alive_k_out: torch.Tensor | None = None  # flat k_proj output rows to keep
    alive_v_out: torch.Tensor | None = None  # flat v_proj output rows to keep
    alive_o_out: torch.Tensor | None = None  # o_proj output dims to keep
    n_alive_q_heads: int = 0
    n_alive_kv_heads: int = 0
    alive_gate_out: torch.Tensor | None = None  # MLP gate_proj surviving channels
    alive_up_out: torch.Tensor | None = None  # MLP up_proj surviving channels
    alive_down_in: torch.Tensor | None = None  # down_proj input channels (same as gate/up)
    alive_down_out: torch.Tensor | None = None  # down_proj output dims to keep
class DeepChaosScheduler:
    """Stochastic structured-dropout scheduler ("deep chaos v3").

    Every ``sticky_interval`` steps it re-samples, per victim layer: whether
    the layer runs at all, which attention (kv) heads survive, which MLP
    channel groups survive, and which hidden-dim groups receive output — then
    monkey-patches each victim layer's ``forward`` to compute only the
    surviving slices. Sacred layers are never touched; streak counters bound
    how long any layer stays on or off. Patched forwards are bypassed in eval
    mode.

    NOTE(review): assumes a Llama-style decoder layer (self_attn with
    q/k/v/o_proj, mlp with gate/up/down_proj, input/post_attention layernorms)
    — confirm against the loaded model.
    """

    def __init__(
        self,
        target_model,
        sacred_layers,
        victim_range,
        min_layer_survival,
        max_layer_survival,
        min_head_survival,
        max_head_survival,
        channel_group_size,
        min_channel_survival,
        max_channel_survival,
        mlp_gate_group_size,
        min_mlp_gate_survival,
        max_mlp_gate_survival,
        hidden_group_size,
        min_hidden_survival,
        max_hidden_survival,
        max_consecutive_on,
        max_consecutive_off,
        sticky_interval,
        seed,
    ):
        self.model = target_model
        self.sacred = set(sacred_layers)
        # victim_range is half-open: the upper bound layer is never a victim.
        self.victims = list(range(victim_range[0], victim_range[1]))
        self.min_ls = min_layer_survival
        self.max_ls = max_layer_survival
        self.min_hs = min_head_survival
        self.max_hs = max_head_survival
        self.cgs = channel_group_size
        self.min_cs = min_channel_survival
        self.max_cs = max_channel_survival
        self.mgg = mlp_gate_group_size
        self.min_mgs = min_mlp_gate_survival
        self.max_mgs = max_mlp_gate_survival
        self.hgg = hidden_group_size
        self.min_hds = min_hidden_survival
        self.max_hds = max_hidden_survival
        self.max_on = max_consecutive_on
        self.max_off = max_consecutive_off
        self.sticky_interval = max(1, int(sticky_interval))
        self.seed = seed
        # Per-layer consecutive on/off counters for the streak limits.
        self.on_streak = {layer: 0 for layer in self.victims}
        self.off_streak = {layer: 0 for layer in self.victims}
        self.topologies = {}
        self.active_layers = set(self.sacred)
        self.last_shuffle_step = None
        self.cached_stats = None
        # Infer model dimensions from the first victim layer's weights.
        layers = resolve_transformer_layers(target_model)
        sample = layers[self.victims[0]]
        self.hs = sample.self_attn.q_proj.weight.shape[1]
        self.q_dim = sample.self_attn.q_proj.weight.shape[0]
        self.k_dim = sample.self_attn.k_proj.weight.shape[0]
        self.v_dim = sample.self_attn.v_proj.weight.shape[0]
        self.inter = sample.mlp.gate_proj.weight.shape[0]
        # Head counts: prefer the config, fall back to attention-module attrs,
        # finally derive head_dim from hidden_size / num_heads.
        cfg = getattr(target_model, "config", None)
        self.nh = getattr(cfg, "num_attention_heads", None)
        self.nkv = getattr(cfg, "num_key_value_heads", None)
        self.hd = getattr(cfg, "head_dim", None)
        if self.nh is None:
            self.nh = getattr(sample.self_attn, "num_heads", None)
        if self.nkv is None:
            self.nkv = getattr(sample.self_attn, "num_key_value_heads", None)
        if self.hd is None:
            self.hd = getattr(sample.self_attn, "head_dim", None)
        if self.nh is None:
            raise AttributeError("Could not determine num_attention_heads from model config or attention module")
        if self.nkv is None:
            self.nkv = self.nh
        if self.hd is None:
            self.hd = self.hs // self.nh
        self.gqa_r = self.nh // self.nkv
        # Group bookkeeping: n_* = number of groups, a_* = actual group size.
        self.n_cg = max(1, self.inter // self.cgs)
        self.a_cg = self.inter // self.n_cg
        self.n_hg = max(1, self.hs // self.hgg)
        self.a_hg = self.hs // self.n_hg
        self.original_forwards = {}
        self._install()

    def _install(self):
        """Swap each victim layer's forward for the sliced version (originals kept for remove())."""
        layers = resolve_transformer_layers(self.model)
        for idx in self.victims:
            self.original_forwards[idx] = layers[idx].forward
            layers[idx].forward = self._make_forward(idx, layers[idx])
        print(
            f"DeepChaosScheduler v3: victims={self.victims[0]}-{self.victims[-1]} sacred={sorted(self.sacred)}"
        )
        print(f" heads={self.nh} kv={self.nkv} head_dim={self.hd} hidden={self.hs} inter={self.inter}")

    def _sample_groups(self, rng, num_groups, group_size, total_dim, min_rate, max_rate, device):
        """Sample surviving contiguous index groups; returns a flat LongTensor of kept indices."""
        rate = rng.uniform(min_rate, max_rate)
        num_alive = max(1, int(round(num_groups * rate)))
        alive_groups = sorted(rng.sample(range(num_groups), num_alive))
        indices = []
        for group in alive_groups:
            start = group * group_size
            indices.extend(range(start, min(start + group_size, total_dim)))
        return torch.tensor(indices, dtype=torch.long, device=device)

    def _heads_to_indices(self, heads, head_dim, device):
        """Expand head indices to the flat per-dimension indices they occupy."""
        indices = []
        for head in heads:
            indices.extend(range(head * head_dim, (head + 1) * head_dim))
        return torch.tensor(indices, dtype=torch.long, device=device)

    def _apply_rotary(self, q, k, cos, sin, position_ids=None):
        """Apply RoPE, normalizing cos/sin shapes across transformers versions."""
        hd = self.hd
        seq_len = q.shape[2]
        def fix_rope(t):
            # Coerce cos/sin to (1, seq_len, head_dim) regardless of how the
            # rotary embedding emitted them.
            t = t.float()
            if t.shape[-1] != hd:
                t = t[..., :hd]
            while t.dim() > 2:
                t = t.squeeze(0)
            if t.shape[0] > seq_len:
                if position_ids is not None:
                    t = t[position_ids]
                else:
                    t = t[:seq_len]
            while t.dim() < 3:
                t = t.unsqueeze(0)
            return t
        cos = fix_rope(cos).to(q.dtype)
        sin = fix_rope(sin).to(q.dtype)
        return apply_rotary_pos_emb(q, k, cos, sin)

    def _subsample(self, base, group_size, min_rate, max_rate, rng):
        """Second-stage group subsampling of an already-surviving index set."""
        n = len(base)
        num_groups = max(1, n // group_size)
        if num_groups <= 1:
            return base
        num_alive = max(1, int(round(num_groups * rng.uniform(min_rate, max_rate))))
        alive_groups = sorted(rng.sample(range(num_groups), num_alive))
        idx = []
        for group in alive_groups:
            start = group * group_size
            idx.extend(range(start, min(start + group_size, n)))
        return base[idx] if isinstance(base, torch.Tensor) else [base[i] for i in idx]

    def step(self, global_step):
        """Re-sample the topology if the sticky window elapsed; returns stats for logging.

        Deterministic per step: the RNG is seeded with seed + global_step.
        """
        # Sticky: reuse the cached topology until the interval elapses.
        if (
            self.cached_stats is not None
            and self.last_shuffle_step is not None
            and global_step < self.last_shuffle_step + self.sticky_interval
        ):
            return self.cached_stats
        self.last_shuffle_step = global_step
        block_idx = global_step // self.sticky_interval
        rng = random.Random(self.seed + global_step)
        device = next(self.model.parameters()).device
        # --- choose which victim layers are active this block ---
        num_victims = len(self.victims)
        min_active = max(1, int(round(num_victims * self.min_ls)))
        max_active = max(min_active, int(round(num_victims * self.max_ls)))
        target = rng.randint(min_active, max_active)
        # Streak limits: force layers on/off when they've been off/on too long.
        forced_on = [layer for layer in self.victims if self.off_streak[layer] >= self.max_off]
        forced_off = [layer for layer in self.victims if self.on_streak[layer] >= self.max_on]
        available = [layer for layer in self.victims if layer not in forced_on and layer not in forced_off]
        active = set(forced_on)
        need = target - len(active)
        if need > 0 and available:
            active.update(rng.sample(available, min(need, len(available))))
        for layer in self.victims:
            if layer in active:
                self.on_streak[layer] += 1
                self.off_streak[layer] = 0
            else:
                self.on_streak[layer] = 0
                self.off_streak[layer] += 1
        self.active_layers = active | self.sacred
        # --- per-layer mode draw: 30% both, 25% attn, 25% mlp, 20% identity ---
        modes = {}
        for layer in self.victims:
            if layer not in self.active_layers:
                modes[layer] = "dead"
                continue
            draw = rng.random()
            if draw < 0.30:
                modes[layer] = "both"
            elif draw < 0.55:
                modes[layer] = "attn"
            elif draw < 0.80:
                modes[layer] = "mlp"
            else:
                modes[layer] = "identity"
        # --- build the per-layer topologies and accounting ---
        self.topologies.clear()
        sliced_elements = 0
        full_elements = 0
        stats = {key: [] for key in ("q", "k", "v", "o", "gate", "up", "down")}
        for layer in self.victims:
            topo = LayerTopology(mode=modes[layer])
            if topo.mode in ("dead", "identity"):
                self.topologies[layer] = topo
                continue
            if topo.mode in ("both", "attn"):
                # Sample surviving kv heads, then keep all q heads in each
                # surviving GQA group so grouping stays consistent.
                kv_rate = rng.uniform(self.min_hs, self.max_hs)
                n_alive_kv = max(1, int(round(self.nkv * kv_rate)))
                alive_kv = sorted(rng.sample(range(self.nkv), n_alive_kv))
                topo.alive_kv_heads = alive_kv
                topo.n_alive_kv_heads = len(alive_kv)
                alive_q = []
                for kv_head in alive_kv:
                    for q_head in range(kv_head * self.gqa_r, (kv_head + 1) * self.gqa_r):
                        if q_head < self.nh:
                            alive_q.append(q_head)
                topo.alive_q_heads = alive_q
                topo.n_alive_q_heads = len(alive_q)
                topo.alive_q_out = self._heads_to_indices(alive_q, self.hd, device)
                topo.alive_k_out = self._heads_to_indices(alive_kv, self.hd, device)
                topo.alive_v_out = self._heads_to_indices(alive_kv, self.hd, device)
                topo.alive_o_out = self._sample_groups(
                    rng, self.n_hg, self.a_hg, self.hs, self.min_hds, self.max_hds, device
                )
                stats["q"].append(len(topo.alive_q_out) / self.q_dim)
                stats["k"].append(len(topo.alive_k_out) / self.k_dim)
                stats["v"].append(len(topo.alive_v_out) / self.v_dim)
                stats["o"].append(len(topo.alive_o_out) / self.hs)
                sliced_elements += (len(topo.alive_q_out) + len(topo.alive_k_out) + len(topo.alive_v_out)) * self.hs
                sliced_elements += len(topo.alive_o_out) * (topo.n_alive_q_heads * self.hd)
                full_elements += (self.q_dim + self.k_dim + self.v_dim + self.hs) * self.hs
            if topo.mode in ("both", "mlp"):
                # gate/up/down share one surviving channel core so SwiGLU stays aligned.
                channels = self._sample_groups(
                    rng, self.n_cg, self.a_cg, self.inter, self.min_cs, self.max_cs, device
                )
                core_indices = self._subsample(channels, self.mgg, self.min_mgs, self.max_mgs, rng)
                topo.alive_gate_out = core_indices
                topo.alive_up_out = core_indices
                topo.alive_down_in = core_indices
                topo.alive_down_out = self._sample_groups(
                    rng, self.n_hg, self.a_hg, self.hs, self.min_hds, self.max_hds, device
                )
                stats["gate"].append(len(topo.alive_gate_out) / self.inter)
                stats["up"].append(len(topo.alive_up_out) / self.inter)
                stats["down"].append(len(topo.alive_down_out) / self.hs)
                sliced_elements += (len(topo.alive_gate_out) + len(topo.alive_up_out)) * self.hs
                sliced_elements += len(topo.alive_down_out) * len(topo.alive_down_in)
                full_elements += 3 * self.inter * self.hs
            self.topologies[layer] = topo
        # --- aggregate stats for logging ---
        active_count = sum(1 for layer in self.victims if layer in self.active_layers)
        mode_counts = {"both": 0, "attn": 0, "mlp": 0, "identity": 0, "dead": 0}
        for layer in self.victims:
            mode_counts[modes[layer]] += 1
        avg = lambda values: sum(values) / len(values) if values else 0.0
        compute_ratio = sliced_elements / max(1, full_elements)
        self.cached_stats = {
            "shuffle_step": global_step,
            "shuffle_block": block_idx,
            "active_layers": active_count,
            "layer_density_pct": 100 * active_count / max(1, num_victims),
            "mode_both": mode_counts["both"],
            "mode_attn_only": mode_counts["attn"],
            "mode_mlp_only": mode_counts["mlp"],
            "mode_identity": mode_counts["identity"],
            "mode_dead": mode_counts["dead"],
            "avg_q_surv": 100 * avg(stats["q"]),
            "avg_k_surv": 100 * avg(stats["k"]),
            "avg_v_surv": 100 * avg(stats["v"]),
            "avg_o_surv": 100 * avg(stats["o"]),
            "avg_gate_surv": 100 * avg(stats["gate"]),
            "avg_up_surv": 100 * avg(stats["up"]),
            "avg_down_surv": 100 * avg(stats["down"]),
            "compute_ratio": compute_ratio,
            "compute_pct": 100 * compute_ratio,
        }
        print(
            "Deep Chaos Shuffle "
            f"@ step {global_step} block={block_idx} "
            f"compute={self.cached_stats['compute_pct']:.1f}% "
            f"active={active_count} "
            f"both={mode_counts['both']} attn={mode_counts['attn']} "
            f"mlp={mode_counts['mlp']} identity={mode_counts['identity']} dead={mode_counts['dead']}"
        )
        return self.cached_stats

    def _make_forward(self, layer_idx, layer):
        """Build the replacement forward that runs only surviving slices of *layer*."""
        scheduler = self
        def forward(
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
            cache_position=None,
            position_embeddings=None,
            **kwargs,
        ):
            # Eval mode: fall through to the untouched original forward.
            if not scheduler.model.training:
                return scheduler.original_forwards[layer_idx](
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            topo = scheduler.topologies.get(layer_idx)
            # dead/identity: the layer contributes nothing — pass residual through.
            if topo is None or topo.mode in ("dead", "identity"):
                return (hidden_states, None, past_key_value)
            residual = hidden_states
            batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
            if topo.mode in ("both", "attn"):
                normed = layer.input_layernorm(hidden_states)
                # Compact projections: only surviving head rows are computed.
                q = SlicedLinear.forward(layer.self_attn.q_proj, normed, alive_out=topo.alive_q_out, compact=True)
                k = SlicedLinear.forward(layer.self_attn.k_proj, normed, alive_out=topo.alive_k_out, compact=True)
                v = SlicedLinear.forward(layer.self_attn.v_proj, normed, alive_out=topo.alive_v_out, compact=True)
                n_alive_q = len(topo.alive_q_heads or [])
                n_alive_kv = len(topo.alive_kv_heads or [])
                expected_q_width = n_alive_q * scheduler.hd
                expected_kv_width = n_alive_kv * scheduler.hd
                if q.shape[-1] != expected_q_width:
                    raise RuntimeError(
                        f"Layer {layer_idx}: q width mismatch {q.shape[-1]} != {expected_q_width}"
                    )
                if k.shape[-1] != expected_kv_width:
                    raise RuntimeError(
                        f"Layer {layer_idx}: k width mismatch {k.shape[-1]} != {expected_kv_width}"
                    )
                if v.shape[-1] != expected_kv_width:
                    raise RuntimeError(
                        f"Layer {layer_idx}: v width mismatch {v.shape[-1]} != {expected_kv_width}"
                    )
                q_4d = q.view(batch_size, seq_len, n_alive_q, scheduler.hd).transpose(1, 2)
                k_4d = k.view(batch_size, seq_len, n_alive_kv, scheduler.hd).transpose(1, 2)
                v_4d = v.view(batch_size, seq_len, n_alive_kv, scheduler.hd).transpose(1, 2)
                # GQA: replicate kv heads so every surviving q head has its kv.
                if scheduler.gqa_r > 1:
                    k_4d = k_4d.repeat_interleave(scheduler.gqa_r, dim=1)
                    v_4d = v_4d.repeat_interleave(scheduler.gqa_r, dim=1)
                if k_4d.shape[1] != q_4d.shape[1] or v_4d.shape[1] != q_4d.shape[1]:
                    raise RuntimeError(
                        f"Layer {layer_idx}: GQA head mismatch "
                        f"q={q_4d.shape[1]} k={k_4d.shape[1]} v={v_4d.shape[1]}"
                    )
                if position_embeddings is not None:
                    cos, sin = position_embeddings
                    q_4d, k_4d = scheduler._apply_rotary(q_4d, k_4d, cos, sin, position_ids)
                # Normalize the attention mask into the 4D additive form SDPA wants.
                attn_mask_4d = None
                if attention_mask is not None and attention_mask.dim() == 4:
                    if attention_mask.shape[1] in (1, n_alive_q):
                        attn_mask_4d = attention_mask
                    elif attention_mask.shape[1] == scheduler.nh:
                        raise RuntimeError(
                            "Head-specific 4D attention masks are not supported with compact deep-chaos heads"
                        )
                    else:
                        attn_mask_4d = attention_mask
                elif attention_mask is not None and attention_mask.dim() == 2:
                    attn_mask_4d = attention_mask[:, None, None, :].to(dtype=q_4d.dtype)
                    attn_mask_4d = (1.0 - attn_mask_4d) * torch.finfo(q_4d.dtype).min
                attn_out = F.scaled_dot_product_attention(
                    q_4d,
                    k_4d,
                    v_4d,
                    attn_mask=attn_mask_4d,
                    dropout_p=0.0,
                    is_causal=(attn_mask_4d is None),
                )
                if attn_out.shape[1] != n_alive_q or attn_out.shape[3] != scheduler.hd:
                    raise RuntimeError(
                        f"Layer {layer_idx}: attention output shape mismatch {tuple(attn_out.shape)}"
                    )
                attn_compact = attn_out.transpose(1, 2).contiguous().reshape(batch_size, seq_len, -1)
                # o_proj sliced on both sides: surviving input head dims and output dims.
                alive_o_in = scheduler._heads_to_indices(topo.alive_q_heads, scheduler.hd, hidden_states.device)
                if attn_compact.shape[-1] != alive_o_in.numel():
                    raise RuntimeError(
                        f"Layer {layer_idx}: compact attention width mismatch "
                        f"{attn_compact.shape[-1]} != {alive_o_in.numel()}"
                    )
                o_weight = layer.self_attn.o_proj.weight[topo.alive_o_out][:, alive_o_in]
                o_bias = (
                    layer.self_attn.o_proj.bias[topo.alive_o_out]
                    if layer.self_attn.o_proj.bias is not None
                    else None
                )
                attn_compact_out = F.linear(attn_compact, o_weight, o_bias)
                # Scatter back to full width, then the usual residual add.
                attn_hidden = hidden_states.new_zeros(batch_size, seq_len, scheduler.hs)
                attn_hidden[..., topo.alive_o_out] = attn_compact_out
                hidden_states = residual + attn_hidden
            else:
                hidden_states = residual
            residual = hidden_states
            if topo.mode in ("both", "mlp"):
                normed = layer.post_attention_layernorm(hidden_states)
                gate = SlicedLinear.forward(layer.mlp.gate_proj, normed, alive_out=topo.alive_gate_out, compact=True)
                up = SlicedLinear.forward(layer.mlp.up_proj, normed, alive_out=topo.alive_up_out, compact=True)
                if gate.shape[-1] != topo.alive_down_in.numel() or up.shape[-1] != topo.alive_down_in.numel():
                    raise RuntimeError(
                        f"Layer {layer_idx}: SwiGLU core mismatch "
                        f"gate={gate.shape[-1]} up={up.shape[-1]} core={topo.alive_down_in.numel()}"
                    )
                activated = F.silu(gate) * up
                weight = layer.mlp.down_proj.weight[topo.alive_down_out][:, topo.alive_down_in]
                bias = (
                    layer.mlp.down_proj.bias[topo.alive_down_out]
                    if layer.mlp.down_proj.bias is not None
                    else None
                )
                mlp_compact = F.linear(activated, weight, bias)
                mlp_hidden = hidden_states.new_zeros(batch_size, seq_len, scheduler.hs)
                mlp_hidden[..., topo.alive_down_out] = mlp_compact
                hidden_states = residual + mlp_hidden
            else:
                hidden_states = residual
            return (hidden_states, None, past_key_value)
        return forward

    def freeze_topology(self, step):
        """Force a (re)sample at *step* so the topology is fixed before use."""
        self.step(step)

    def remove(self):
        """Restore the original layer forwards and drop all cached state."""
        layers = resolve_transformer_layers(self.model)
        for idx, original in self.original_forwards.items():
            layers[idx].forward = original
        self.original_forwards.clear()
        self.topologies.clear()
        self.cached_stats = None
        self.last_shuffle_step = None
        print("DeepChaosScheduler removed")
# Install the deep-chaos scheduler (patches victim layer forwards) when enabled;
# downstream code treats deep_chaos is None as "feature disabled".
deep_chaos = None
if C.DEEP_CHAOS_ENABLED:
    deep_chaos = DeepChaosScheduler(
        model,
        C.DEEP_CHAOS_SACRED_LAYERS,
        C.DEEP_CHAOS_VICTIM_RANGE,
        C.DEEP_CHAOS_MIN_LAYER_SURVIVAL,
        C.DEEP_CHAOS_MAX_LAYER_SURVIVAL,
        C.DEEP_CHAOS_MIN_HEAD_SURVIVAL,
        C.DEEP_CHAOS_MAX_HEAD_SURVIVAL,
        C.DEEP_CHAOS_CHANNEL_GROUP_SIZE,
        C.DEEP_CHAOS_MIN_CHANNEL_SURVIVAL,
        C.DEEP_CHAOS_MAX_CHANNEL_SURVIVAL,
        C.DEEP_CHAOS_MLP_GATE_GROUP_SIZE,
        C.DEEP_CHAOS_MIN_MLP_GATE_SURVIVAL,
        C.DEEP_CHAOS_MAX_MLP_GATE_SURVIVAL,
        C.DEEP_CHAOS_HIDDEN_GROUP_SIZE,
        C.DEEP_CHAOS_MIN_HIDDEN_SURVIVAL,
        C.DEEP_CHAOS_MAX_HIDDEN_SURVIVAL,
        C.DEEP_CHAOS_MAX_CONSECUTIVE_ON,
        C.DEEP_CHAOS_MAX_CONSECUTIVE_OFF,
        C.DEEP_CHAOS_STICKY_INTERVAL,
        C.DEEP_CHAOS_SEED,
    )
# Probe vocabulary for the representation probe: fill these in for your own
# persona — this set was written for the author's bartender personality model.
# "related" = persona-aligned words, "unrelated" = assistant-deference words.
PROBE_TOKENS = {
    "related": ["gritty", "casual", "weird", "bartender", "conversation", "direct", "physical", "grounded"],
    "unrelated": ["certainly", "apologize", "assist", "helpful", "serving", "comply", "obedient", "dutifully"],
}
def get_token_ids(tok, words):
    """Return the first token id of each word; words that encode to nothing are dropped."""
    return [
        pieces[0]
        for word in words
        if (pieces := tok.encode(word, add_special_tokens=False))
    ]
# Pre-compute probe token ids once with the training tokenizer.
PROBE_RELATED_IDS = get_token_ids(tokenizer, PROBE_TOKENS["related"])
PROBE_UNRELATED_IDS = get_token_ids(tokenizer, PROBE_TOKENS["unrelated"])
PROBE_ALL_IDS = PROBE_RELATED_IDS + PROBE_UNRELATED_IDS
def _capture_probe_state():
    """Snapshot the injector noise scale and assistant mask so a probe pass can restore them.

    Returns None when injectors are disabled.
    """
    if injector_mgr is None:
        return None
    scale = None
    for inj in injector_mgr.injectors.values():
        # All injectors share one scale; the first one is representative.
        scale = inj.noise_scale
        break
    return {
        "noise_scale": scale,
        "assistant_mask": injector_mgr.current_assistant_mask,
    }
def _disable_injectors_for_probe():
    """Zero injector noise and drop cached state so a probe pass sees the clean model."""
    if injector_mgr is not None:
        injector_mgr.set_noise_scale(0.0)
        injector_mgr.current_assistant_mask = None
        injector_mgr.clear_cache()
def _restore_probe_state(state):
    """Undo _disable_injectors_for_probe using a snapshot from _capture_probe_state."""
    if injector_mgr is None or state is None:
        return
    scale = state["noise_scale"]
    if scale is not None:
        injector_mgr.set_noise_scale(scale)
    injector_mgr.current_assistant_mask = state["assistant_mask"]
def extract_layer_embeddings(target_model, tok, token_ids, layer_range, use_chaos=False):
    """Run the probe token ids through the model, capturing each layer's output.

    use_chaos=True keeps the model in train mode so the patched deep-chaos
    forwards stay active; otherwise the model is switched to eval for a dense
    pass. Injectors are silenced for the duration and their state restored
    afterwards, as is the model's original train/eval mode.

    Returns {layer_idx: numpy array of that layer's hidden states}.
    """
    activations = {}
    hooks = []
    layers = resolve_transformer_layers(target_model)
    was_training = target_model.training
    probe_state = _capture_probe_state()
    if use_chaos:
        target_model.train()
    else:
        target_model.eval()
    _disable_injectors_for_probe()
    def make_hook(layer_idx):
        def hook_fn(module, inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            activations[layer_idx] = hidden.detach().float().cpu().numpy()
        return hook_fn
    # layer_range is inclusive on both ends.
    for layer_idx in range(layer_range[0], layer_range[1] + 1):
        hooks.append(layers[layer_idx].register_forward_hook(make_hook(layer_idx)))
    input_ids = torch.tensor([token_ids], device=model_device)
    with torch.no_grad():
        target_model(input_ids, return_dict=True)
    for hook in hooks:
        hook.remove()
    _restore_probe_state(probe_state)
    # Restore the caller's train/eval mode.
    if was_training:
        target_model.train()
    else:
        target_model.eval()
    return activations
def compute_sim_matrix(embeddings):
    """Row-wise cosine-similarity matrix; all-zero rows divide by 1 instead of 0."""
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    safe_norms = np.where(row_norms == 0, 1, row_norms)
    unit_rows = embeddings / safe_norms
    return unit_rows @ unit_rows.T
def compute_pair_sim(a, b):
    """Cosine similarity between two activation arrays flattened to float32 vectors.

    Returns 0.0 when either vector has zero norm.
    """
    u = a.reshape(-1).astype(np.float32)
    v = b.reshape(-1).astype(np.float32)
    u_norm = np.linalg.norm(u)
    v_norm = np.linalg.norm(v)
    if u_norm > 0 and v_norm > 0:
        return float(np.dot(u, v) / (u_norm * v_norm))
    return 0.0
def run_bol_probe(target_model, tok, step, layer_range=C.PROBE_LAYER_RANGE, use_chaos=False):
    """Probe per-layer separation between related and unrelated probe tokens.

    Embeds both probe token groups at every layer in ``layer_range`` and
    reports, per layer, the mean off-diagonal cosine similarity within each
    group plus their difference ("separation"). Also measures redundancy
    (pairwise similarity) between injector layers and between the last three
    layers of the range.

    Returns a dict with keys: step, mode, layers, pairs, summary.
    """
    related_ids = get_token_ids(tok, PROBE_TOKENS["related"])
    unrelated_ids = get_token_ids(tok, PROBE_TOKENS["unrelated"])
    n_rel = len(related_ids)
    n_unrel = len(unrelated_ids)
    layer_acts = extract_layer_embeddings(
        target_model, tok, related_ids + unrelated_ids, layer_range, use_chaos=use_chaos
    )
    report = {"step": step, "layers": {}, "pairs": {}, "mode": "chaos" if use_chaos else "dense"}
    for idx in sorted(layer_acts):
        emb = layer_acts[idx][0]
        sim = compute_sim_matrix(emb)
        # Off-diagonal masks: exclude each token's trivial self-similarity.
        off_diag_rel = ~np.eye(n_rel, dtype=bool)
        off_diag_unrel = ~np.eye(n_unrel, dtype=bool)
        rel_mean = sim[:n_rel, :n_rel][off_diag_rel].mean() if off_diag_rel.sum() > 0 else 0.0
        unrel_mean = sim[n_rel:, n_rel:][off_diag_unrel].mean() if off_diag_unrel.sum() > 0 else 0.0
        report["layers"][idx] = {
            "identity_sim": float(rel_mean),
            "deference_sim": float(unrel_mean),
            "separation": float(rel_mean - unrel_mean),
            "activation_variance": float(np.var(emb)),
        }
    groups = {
        "injector": [layer for layer in C.INJECTOR_LAYERS if layer in layer_acts],
        "late": [
            layer
            for layer in range(max(layer_range[1] - 2, layer_range[0]), layer_range[1] + 1)
            if layer in layer_acts
        ],
    }
    for group_name, members in groups.items():
        for pos, first in enumerate(members):
            for second in members[pos + 1 :]:
                report["pairs"][f"{group_name}_{first}_vs_{second}"] = compute_pair_sim(
                    layer_acts[first][0],
                    layer_acts[second][0],
                )
    injector_sims = [v for k, v in report["pairs"].items() if k.startswith("injector_")]
    late_sims = [v for k, v in report["pairs"].items() if k.startswith("late_")]
    report["summary"] = {
        "mean_separation": float(np.mean([v["separation"] for v in report["layers"].values()])),
        "mean_injector_redundancy": float(np.mean(injector_sims)) if injector_sims else 0.0,
        "mean_late_redundancy": float(np.mean(late_sims)) if late_sims else 0.0,
    }
    return report
def print_probe_report(results):
    """Pretty-print a run_bol_probe result dict to stdout."""
    rule = "=" * 65
    summary = results["summary"]
    print(f"\n{rule}")
    print(f"BoL PROBE [{results.get('mode', 'dense').upper()}] @ step {results['step']}")
    print(rule)
    print(f" Mean Separation: {summary['mean_separation']:.4f}")
    print(f" Injector Redundancy: {summary['mean_injector_redundancy']:.4f}")
    print(f" Late Redundancy: {summary['mean_late_redundancy']:.4f}")
    print("_" * 65)
    for layer_idx in sorted(results["layers"]):
        values = results["layers"][layer_idx]
        # Tag rows belonging to injector layers so they stand out in the table.
        inj = " <-INJ" if layer_idx in C.INJECTOR_LAYERS else ""
        print(
            f" {layer_idx:<4} id={values['identity_sim']:.4f} def={values['deference_sim']:.4f} "
            f"sep={values['separation']:.4f} var={values['activation_variance']:.4f}{inj}"
        )
    print(f"{rule}\n")
class ProbeCallback(TrainerCallback):
    """Trainer callback that runs the BoL probe every N steps.

    Probes in dense mode always, and additionally in chaos mode when a
    deep-chaos controller is supplied. Results are written to a JSON file in
    the output directory, committed to the Modal volume, and logged to wandb.
    """

    def __init__(self, tok, probe_every, deep_chaos_ref=None):
        self.tokenizer = tok
        self.probe_every = probe_every  # probe cadence in global steps
        self.deep_chaos = deep_chaos_ref

    def on_step_end(self, args, state, control, model=None, **kwargs):
        step = state.global_step
        if not C.PROBE_ENABLED or step == 0 or step % self.probe_every != 0:
            return control
        dense_report = run_bol_probe(model, self.tokenizer, step, use_chaos=False)
        print_probe_report(dense_report)
        chaos_report = None
        if self.deep_chaos is not None:
            # Freeze the chaos topology so the probe sees a stable mask.
            self.deep_chaos.freeze_topology(step)
            chaos_report = run_bol_probe(model, self.tokenizer, step, use_chaos=True)
            print_probe_report(chaos_report)
        payload = {"dense": dense_report}
        if chaos_report is not None:
            payload["chaos"] = chaos_report
        Path(C.OUTPUT_DIR, f"probe_step_{step}.json").write_text(json.dumps(payload, indent=2))
        volume.commit()
        if wandb.run:
            wandb.log(
                {
                    "probe_dense/separation": dense_report["summary"]["mean_separation"],
                    "probe_dense/late_redundancy": dense_report["summary"]["mean_late_redundancy"],
                },
                step=step,
            )
            if chaos_report is not None:
                wandb.log(
                    {
                        "probe_chaos/separation": chaos_report["summary"]["mean_separation"],
                        "probe_chaos/late_redundancy": chaos_report["summary"]["mean_late_redundancy"],
                    },
                    step=step,
                )
        return control
class BoundaryProbeCallback(TrainerCallback):
    """Trainer callback that measures role-boundary leakage.

    Every ``probe_every`` steps it checks how much next-token probability the
    model assigns to "forbidden" role-marker tokens (user/assistant tags) at
    the end of each probe prompt, in dense mode and — when a deep-chaos
    controller is provided — in chaos mode as well.
    """

    def __init__(self, tok, probe_every, deep_chaos_ref=None):
        self.tokenizer = tok
        self.probe_every = probe_every  # cadence in global steps
        self.deep_chaos = deep_chaos_ref
        # NOTE(review): the prompt list is empty in this source; with no
        # prompts, _run_boundary never loops and always reports 0.0 —
        # populate before relying on this probe.
        self.prompts = [
        ]
        # Token ids the model should never emit as free text: every id that
        # appears when encoding the role-marker strings below.
        self.forbidden_ids = set()
        for text in ["<|im_start|>user", "<|im_start|>assistant", "user", "assistant", "User", "Assistant"]:
            self.forbidden_ids.update(self.tokenizer.encode(text, add_special_tokens=False))

    def _run_boundary(self, target_model):
        """Return the maximum, over prompts, of the total probability mass
        placed on forbidden token ids at the final position."""
        max_forbidden = 0.0
        for prompt in self.prompts:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(model_device)
            with torch.no_grad():
                # Softmax over the logits of the last position only.
                probs = torch.softmax(target_model(**inputs).logits[0, -1, :], dim=-1)
            max_forbidden = max(
                max_forbidden,
                # Guard against ids beyond the vocab size of this model.
                sum(probs[token_id].item() for token_id in self.forbidden_ids if token_id < probs.shape[0]),
            )
        return max_forbidden

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Run the boundary probe; always restores injector state and the
        model's train/eval mode, even if the probe raises."""
        if model is None or state.global_step == 0 or state.global_step % self.probe_every != 0:
            return control
        was_training = model.training
        probe_state = _capture_probe_state()
        try:
            # Dense pass: eval mode, injectors silenced.
            model.eval()
            _disable_injectors_for_probe()
            dense_forbidden = self._run_boundary(model)
            print(f"Boundary [DENSE] @ {state.global_step}: max_forbidden={dense_forbidden:.6f}")
            chaos_forbidden = 0.0
            if self.deep_chaos is not None:
                # Chaos pass: train() re-activates chaos masks; the topology
                # is frozen so the measurement is deterministic.
                model.train()
                self.deep_chaos.freeze_topology(state.global_step)
                _disable_injectors_for_probe()
                chaos_forbidden = self._run_boundary(model)
                print(f"Boundary [CHAOS] @ {state.global_step}: max_forbidden={chaos_forbidden:.6f}")
            if wandb.run:
                wandb.log(
                    {
                        "boundary/dense_max_forbidden": dense_forbidden,
                        "boundary/chaos_max_forbidden": chaos_forbidden,
                    },
                    step=state.global_step,
                )
        finally:
            # Order matters: restore injector noise/mask first, then the
            # original train/eval mode.
            _restore_probe_state(probe_state)
            if was_training:
                model.train()
            else:
                model.eval()
        return control
# Ensure the tokenizer has a chat template; fall back to the llama-3 template
# with an identity role/content mapping when none is configured.
if not getattr(tokenizer, "chat_template", None):
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="llama-3",
        mapping={"role": "role", "content": "content", "user": "user", "assistant": "assistant"},
    )
# Some tokenizers ship without a pad token; reuse EOS so batching/padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Fixed seed keeps the shuffle (and hence label sanity checks) reproducible.
dataset = load_dataset(C.DATASET_NAME, split=C.DATASET_SPLIT).shuffle(seed=42)
print(f"Dataset: {C.DATASET_NAME} ({len(dataset)} samples)")
def tokenize_conversation(convo):
    """Render a message list with the chat template and tokenize it.

    Truncates to C.MAX_SEQ_LENGTH and returns the tokenizer's dict output
    (input_ids, attention_mask).
    """
    encoded = tokenizer.apply_chat_template(
        convo,
        tokenize=True,
        add_generation_prompt=False,
        return_dict=True,
        truncation=True,
        max_length=C.MAX_SEQ_LENGTH,
    )
    return encoded
def build_labels(convo, input_ids):
    """Mask everything except assistant spans with -100.

    For each assistant message, the supervised token span is located by
    re-tokenizing the conversation up to (exclusive) and through (inclusive)
    that message and taking the difference in lengths. Positions past the
    truncated ``input_ids`` are clamped away.
    """
    total = len(input_ids)
    labels = [-100] * total
    for msg_idx, message in enumerate(convo):
        if message.get("role") != "assistant":
            continue
        span_start = len(tokenize_conversation(convo[:msg_idx])["input_ids"]) if msg_idx > 0 else 0
        span_end = len(tokenize_conversation(convo[: msg_idx + 1])["input_ids"])
        for pos in range(min(span_start, total), min(span_end, total)):
            labels[pos] = input_ids[pos]
    return labels
def format_batch(examples):
    """Batched dataset.map fn: tokenize conversations and attach labels.

    Returns parallel lists of input_ids, attention_mask, and assistant-only
    labels, one entry per conversation in examples["messages"].
    """
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for convo in examples["messages"]:
        encoded = tokenize_conversation(convo)
        batch["input_ids"].append(encoded["input_ids"])
        batch["attention_mask"].append(encoded["attention_mask"])
        batch["labels"].append(build_labels(convo, encoded["input_ids"]))
    return batch
# Normalize to role/content message dicts, then tokenize + label everything,
# dropping the original columns so only model inputs remain.
dataset_formatted = standardize_sharegpt(dataset)
dataset_formatted = dataset_formatted.map(format_batch, batched=True, remove_columns=dataset_formatted.column_names)
os.makedirs(C.OUTPUT_DIR, exist_ok=True)
# Baseline probes at step 0, before any training, so later probe reports
# have a pre-training reference point.
if C.PROBE_ENABLED:
    print("Baseline probe (dense)...")
    baseline_dense = run_bol_probe(model, tokenizer, 0, use_chaos=False)
    print_probe_report(baseline_dense)
    baseline_chaos = None
    if deep_chaos is not None:
        print("Baseline probe (chaos)...")
        # Freeze topology so the chaos baseline is deterministic.
        deep_chaos.freeze_topology(0)
        baseline_chaos = run_bol_probe(model, tokenizer, 0, use_chaos=True)
        print_probe_report(baseline_chaos)
    Path(C.OUTPUT_DIR, "probe_baseline.json").write_text(
        json.dumps({"dense": baseline_dense, "chaos": baseline_chaos}, indent=2)
    )
# Start experiment tracking and record the full config.
wandb.init(project=C.WANDB_PROJECT, name=C.WANDB_RUN_NAME)
wandb.config.update(config)
class DeepChaosSFTTrainer(Trainer):
    """Trainer that adds Deep-Chaos dropout and injector regularizers on top
    of the standard causal-LM cross-entropy loss.

    The class attributes below are populated externally after construction
    (see the ``trainer.injector_mgr = ...`` assignments later in the file).
    """

    # Injector manager providing hooks/caches/regularizer losses; may stay None.
    injector_mgr = None
    # Frozen reference model used for the preservation KL term; may stay None.
    base_model_ref = None
    # Deep-Chaos topology controller; may stay None.
    deep_chaos = None
    # Most recent stats dict returned by deep_chaos.step().
    deep_chaos_stats = None

    def compute_loss(self, target_model, inputs, return_outputs=False, **kwargs):
        """Cross-entropy loss plus ortho/stability/preservation/decorrelation
        terms, with optional Deep-Chaos mask stepping and wandb logging.

        The statement order is deliberate: the chaos topology advances first,
        then (if stability loss is on) a clean noise-free forward pass is
        snapshotted BEFORE the real, noised forward pass.
        """
        # Advance the chaos topology for this optimizer step.
        if self.deep_chaos is not None:
            self.deep_chaos_stats = self.deep_chaos.step(self.state.global_step)
        labels = inputs.get("labels")
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        assistant_mask = None
        clean_cache = {}
        if self.injector_mgr is not None and labels is not None:
            # Supervised (assistant) positions — injector noise applies there.
            assistant_mask = (labels != -100).float()
            self.injector_mgr.current_assistant_mask = assistant_mask
            self.injector_mgr.set_noise_scale(C.INJECTOR_NOISE_SCALE)
            if C.STABILITY_LOSS_WEIGHT > 0:
                # Clean reference pass: zero noise, no grad, snapshot the
                # activation cache, then restore mask + noise for the real pass.
                self.injector_mgr.set_noise_scale(0.0)
                with torch.no_grad():
                    target_model(**model_inputs)
                clean_cache = self.injector_mgr.snapshot_cache(use_post=C.STABILITY_ON_POST_INJECTION, detach=True)
                self.injector_mgr.clear_cache()
                self.injector_mgr.current_assistant_mask = assistant_mask
                self.injector_mgr.set_noise_scale(C.INJECTOR_NOISE_SCALE)
        outputs = target_model(**model_inputs)
        # Standard next-token CE: shift logits left, labels right; -100 ignored.
        # NOTE(review): labels is used unguarded here, so batches are assumed
        # to always carry a "labels" key — the collator below provides it.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = nn.CrossEntropyLoss(ignore_index=-100)(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Auxiliary terms default to zero when their machinery is absent.
        ortho = torch.tensor(0.0, device=target_model.device)
        stab = torch.tensor(0.0, device=target_model.device)
        pres = torch.tensor(0.0, device=target_model.device)
        decorr = torch.tensor(0.0, device=target_model.device)
        if self.injector_mgr is not None:
            ortho = self.injector_mgr.compute_ortho_loss(use_post=not C.ORTHO_ON_PRE_INJECTION)
            decorr = self.injector_mgr.compute_decorrelation_loss(mask=assistant_mask, use_post=True)
            if clean_cache:
                # Compare the noised activations against the clean snapshot.
                perturbed = self.injector_mgr.snapshot_cache(use_post=C.STABILITY_ON_POST_INJECTION, detach=False)
                stab = self.injector_mgr.compute_stability_loss(clean_cache, perturbed, assistant_mask)
        if self.base_model_ref is not None and C.PRESERVATION_KL_WEIGHT > 0 and labels is not None:
            # Preservation KL: keep the tuned model close to the base model on
            # UNsupervised positions (labels == -100), i.e. non-assistant text.
            with torch.no_grad():
                base_outputs = self.base_model_ref(**model_inputs)
            base_logits = base_outputs.logits[..., :-1, :].contiguous()
            tuned_logits = outputs.logits[..., :-1, :].contiguous()
            policy_mask = (labels[..., 1:].contiguous() == -100).float()
            if policy_mask.sum().item() > 0:
                log_p_base = F.log_softmax(base_logits, dim=-1)
                log_p_tuned = F.log_softmax(tuned_logits, dim=-1)
                # KL(base || tuned) per position, averaged over masked positions.
                kl = (log_p_base.exp() * (log_p_base - log_p_tuned)).sum(-1)
                pres = (kl * policy_mask).sum() / policy_mask.sum().clamp_min(1)
        # Note: ortho is added unweighted (its weight is assumed to be applied
        # inside compute_ortho_loss); the other terms carry explicit weights.
        main_loss = (
            loss
            + ortho
            + C.STABILITY_LOSS_WEIGHT * stab
            + C.PRESERVATION_KL_WEIGHT * pres
            + C.DECORRELATION_WEIGHT * decorr
        )
        if self.injector_mgr is not None:
            # Drop cached activations so they do not leak across steps.
            self.injector_mgr.clear_cache()
        if self.state.global_step % C.LOGGING_STEPS == 0 and wandb.run:
            payload = {
                "loss/main": float(main_loss),
                "loss/task": float(loss),
                "ortho/loss": float(ortho),
                "stability/loss": float(stab),
                "preservation/kl": float(pres),
                "decorrelation/loss": float(decorr),
            }
            if self.deep_chaos_stats is not None:
                payload.update({f"dc/{k}": float(v) for k, v in self.deep_chaos_stats.items()})
            wandb.log(payload, step=self.state.global_step)
        return (main_loss, outputs) if return_outputs else main_loss

    def create_optimizer(self):
        """Create the default optimizer, then append injector parameters as an
        extra param group with no weight decay."""
        super().create_optimizer()
        if self.injector_mgr is not None:
            injector_params = self.injector_mgr.get_trainable_params()
            if injector_params:
                self.optimizer.add_param_group({"params": injector_params, "lr": C.LEARNING_RATE, "weight_decay": 0.0})
        return self.optimizer
# Trainer configuration — all tunables come from the CONFIG dict via C.
training_args = TrainingArguments(
    output_dir=C.OUTPUT_DIR,
    num_train_epochs=C.NUM_EPOCHS,
    per_device_train_batch_size=C.BATCH_SIZE,
    gradient_accumulation_steps=C.GRADIENT_ACCUMULATION,
    learning_rate=C.LEARNING_RATE,
    lr_scheduler_type=C.LR_SCHEDULER,
    warmup_ratio=C.WARMUP_RATIO,
    weight_decay=C.WEIGHT_DECAY,
    max_grad_norm=C.MAX_GRAD_NORM,
    optim="adamw_torch",
    # Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    logging_steps=C.LOGGING_STEPS,
    save_steps=C.SAVE_STEPS,
    save_total_limit=2,
    gradient_checkpointing=True,
    report_to="wandb",
    run_name=C.WANDB_RUN_NAME,
    # Keep the custom "labels" column intact for compute_loss.
    remove_unused_columns=False,
    seed=42,
)
# Seq2Seq collator pads input_ids/attention_mask with the pad token and
# labels with -100 so padding never contributes to the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    label_pad_token_id=-100,
    return_tensors="pt",
)
# Build the trainer with both probe callbacks on the same cadence.
trainer = DeepChaosSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_formatted,
    processing_class=tokenizer,
    data_collator=data_collator,
    callbacks=[
        ProbeCallback(tokenizer, C.PROBE_STEPS, deep_chaos),
        BoundaryProbeCallback(tokenizer, C.PROBE_STEPS, deep_chaos),
    ],
)
# Wire in the externally-built components read by compute_loss/create_optimizer.
trainer.injector_mgr = injector_mgr
trainer.base_model_ref = base_model_ref
trainer.deep_chaos = deep_chaos
# Label sanity: the first batch must contain at least one supervised token
# AND at least one masked token, otherwise the labeling logic is broken.
sanity_batch = next(iter(trainer.get_train_dataloader()))
supervised_tokens = (sanity_batch["labels"] != -100).sum().item()
total_tokens = sanity_batch["labels"].numel()
if supervised_tokens <= 0 or supervised_tokens >= total_tokens:
    raise RuntimeError(f"Label sanity failed: sup={supervised_tokens} total={total_tokens}")
print(f"Label sanity: supervised={supervised_tokens} masked={total_tokens - supervised_tokens}")
# Smoke test: one tiny forward pass (1 sample, 64 tokens) under chaos masks
# to catch non-finite logits or shape errors before committing to training.
if deep_chaos is not None:
    print("\nSmoke test...")
    smoke_stats = deep_chaos.step(0)
    print(f" Compute est: {smoke_stats['compute_pct']:.1f}% Layers: {smoke_stats['active_layers']}")
    smoke_inputs = sanity_batch["input_ids"][:1, :64].to(model_device)
    smoke_mask = sanity_batch["attention_mask"][:1, :64].to(model_device)
    # train() so the chaos dropout masks are actually active during the pass.
    model.train()
    with torch.no_grad():
        smoke_outputs = model(input_ids=smoke_inputs, attention_mask=smoke_mask, return_dict=True)
    if not torch.isfinite(smoke_outputs.logits).all():
        raise RuntimeError("Smoke test produced non-finite logits")
    print(f" Forward pass OK: {smoke_outputs.logits.shape}")
# Startup banner, then the actual training run.
print(f"\n{'=' * 60}")
print("READY — DEEP CHAOS v3")
print(f"{'=' * 60}")
print(f" Model: {C.MODEL_NAME}")
print(f" Dataset: {len(dataset_formatted)} samples")
print(f" Deep Chaos: {C.DEEP_CHAOS_ENABLED}")
print(f" Train layers: {freezing_info['train_start']}-{freezing_info['train_end']} of {freezing_info['total_layers'] - 1}")
print(f" Export: {C.HF_REPO}")
print(f"{'=' * 60}")
trainer_stats = trainer.train()
print(f"\nDone. Loss: {trainer_stats.metrics['train_loss']:.4f}")
# Detach forward hooks and chaos machinery before saving, so no probe/chaos
# state is active (or captured) in the exported model.
if injector_mgr is not None:
    injector_mgr.remove_hooks()
if deep_chaos is not None:
    deep_chaos.remove()
final_dir = os.path.join(C.OUTPUT_DIR, "final")
model.save_pretrained(final_dir, safe_serialization=True)
tokenizer.save_pretrained(final_dir)
# Persist the output directory to the Modal volume.
volume.commit()
# Push the final checkpoint to a private Hugging Face repo.
api = HfApi()
api.create_repo(repo_id=C.HF_REPO, private=True, exist_ok=True)
api.upload_folder(repo_id=C.HF_REPO, repo_type="model", folder_path=final_dir)
print(f"Hub push complete: {C.HF_REPO}")
@app.local_entrypoint()
def main():
    # Local entrypoint: launches the training function on Modal's remote infra.
    print("")
    train.remote()