lucky_pick_scheduler

Published April 15, 2026
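
This is the full Modal training script: it fine-tunes a model with Unsloth while a "deep chaos" scheduler randomly drops whole layers, attention heads, MLP channels, and hidden-dim groups on a sticky schedule, with optional entropy injectors, auxiliary losses, and activation probes to watch what the chaos does to the representations. Blank fields (model, dataset, GPU, secrets, and so on) are left for you to fill in.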
import modal
import os


app = modal.App("")

image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.1.1-devel-ubuntu22.04",
        add_python="3.11",
    )
    .apt_install("git", "gcc", "g++")
    .pip_install(
        "torch==2.5.1",
        "torchvision==0.20.1",
        "torchaudio==2.5.1",
        index_url="https://download.pytorch.org/whl/cu121",
    )
    .pip_install("unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git")
    .pip_install(
        "unsloth_zoo",
        "trl",
        "peft",
        "accelerate",
        "bitsandbytes",
        "datasets",
        "wandb",
        "protobuf",
    )
    .run_commands("pip uninstall -y torchao")
)

volume = modal.Volume.from_name("bella-v6-deep-chaos-v3", create_if_missing=True)


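# Everything is driven from this one dict. Blank strings are intentional --
# fill in your model, dataset, output dir, HF repo, and W&B names. Auxiliary
# losses (ortho, stability, preservation-KL, decorrelation) are disabled
# whenever their weight is 0.0.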
CONFIG = dict(
    MODEL_NAME="",
    MAX_SEQ_LENGTH=2048,  # placeholder -- set to your context length
    LOAD_IN_4BIT=True,  # placeholder -- set False for full-precision weights
    MODEL_DTYPE="",
    DATASET_NAME="",
    DATASET_SPLIT="train",
    INJECTOR_ENABLED=False,  # placeholder -- True to enable the EntropyInjector hooks
    INJECTOR_LAYERS=[6, 12, 18],
    INJECTOR_HIDDEN_DIM=128,
    INJECTOR_NOISE_SCALE=0.0,
    DECORRELATION_WEIGHT=0.0,
    PROBE_ALIGNMENT_WEIGHT=0.0,
    PROBE_ALIGNMENT_EVERY=1,
    ORTHO_PEAK_LAYERS=[4, 12, 21],
    ORTHO_LOSS_WEIGHT=0.0,
    ORTHO_ON_PRE_INJECTION=True,
    ORTHO_METRIC="token_wise",
    STABILITY_LOSS_WEIGHT=0.0,
    STABILITY_ON_POST_INJECTION=True,
    PRESERVATION_KL_WEIGHT=0.0,
    LAYER_FREEZE_FRACTION=0.0,
    TRAIN_LAYER_RANGE=(0, 27),
    DEEP_CHAOS_ENABLED=True,
    DEEP_CHAOS_SACRED_LAYERS=[0, 1, 26, 27],
    DEEP_CHAOS_VICTIM_RANGE=(2, 26),
    DEEP_CHAOS_MIN_LAYER_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_LAYER_SURVIVAL=0.70,
    DEEP_CHAOS_MIN_HEAD_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_HEAD_SURVIVAL=0.70,
    DEEP_CHAOS_CHANNEL_GROUP_SIZE=128,
    DEEP_CHAOS_MIN_CHANNEL_SURVIVAL=0.30,
    DEEP_CHAOS_MAX_CHANNEL_SURVIVAL=0.70,
    DEEP_CHAOS_MLP_GATE_GROUP_SIZE=128,
    DEEP_CHAOS_MIN_MLP_GATE_SURVIVAL=0.35,
    DEEP_CHAOS_MAX_MLP_GATE_SURVIVAL=0.80,
    DEEP_CHAOS_HIDDEN_GROUP_SIZE=64,
    DEEP_CHAOS_MIN_HIDDEN_SURVIVAL=0.60,
    DEEP_CHAOS_MAX_HIDDEN_SURVIVAL=0.95,
    DEEP_CHAOS_MAX_CONSECUTIVE_ON=5,
    DEEP_CHAOS_MAX_CONSECUTIVE_OFF=10,
    DEEP_CHAOS_STICKY_INTERVAL=50,
    DEEP_CHAOS_SEED=42,
    NUM_EPOCHS=1,  # placeholder -- set your epoch count
    BATCH_SIZE=2,
    GRADIENT_ACCUMULATION=2,     # 2x2 works best for me; not sure why
    LEARNING_RATE=2e-5,
    LR_SCHEDULER="cosine",
    WARMUP_RATIO=0.15,
    WEIGHT_DECAY=0.01,
    MAX_GRAD_NORM=1.0,
    LOGGING_STEPS=10,
    SAVE_STEPS=2000,
    PROBE_STEPS=100,
    OUTPUT_DIR="",
    HF_REPO="",
    WANDB_PROJECT="",
    WANDB_RUN_NAME="",
    PROBE_ENABLED=True,
    PROBE_LAYER_RANGE=(2, 26),
)


@app.function(
    image=image,
    gpu="",
    timeout=6 * 3600,
    volumes={"/vol": volume},
    secrets=[
        modal.Secret.from_name(""),  # fill in -- e.g. your Hugging Face token secret
        modal.Secret.from_name(""),  # fill in -- e.g. your Weights & Biases secret
    ],
)
def train(config: dict = CONFIG):
    import copy
    import json
    import math
    import random
    from dataclasses import dataclass
    from pathlib import Path

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import wandb
    from datasets import load_dataset
    from huggingface_hub import HfApi
    from unsloth import FastLanguageModel, is_bfloat16_supported
    from unsloth.chat_templates import get_chat_template, standardize_sharegpt
    from transformers import DataCollatorForSeq2Seq, Trainer, TrainerCallback, TrainingArguments
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

    C = type("C", (), config)()

    def resolve_dtype(name):
        return {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }[name]

    def resolve_transformer_layers(target_model):
        if hasattr(target_model, "model") and hasattr(target_model.model, "layers"):
            return target_model.model.layers
        if hasattr(target_model, "base_model"):
            base = target_model.base_model
            if hasattr(base, "model") and hasattr(base.model, "layers"):
                return base.model.layers
            if hasattr(base, "model") and hasattr(base.model, "model") and hasattr(base.model.model, "layers"):
                return base.model.model.layers
        raise AttributeError("Could not find transformer layers")

    def resolve_transformer_root(target_model):
        if hasattr(target_model, "model") and hasattr(target_model.model, "layers"):
            return target_model.model
        if hasattr(target_model, "base_model"):
            base = target_model.base_model
            if hasattr(base, "model") and hasattr(base.model, "layers"):
                return base.model
            if hasattr(base, "model") and hasattr(base.model, "model") and hasattr(base.model.model, "layers"):
                return base.model.model
        raise AttributeError("Could not find transformer root")

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"PyTorch: {torch.__version__}  CUDA: {torch.version.cuda}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=C.MODEL_NAME,
        max_seq_length=C.MAX_SEQ_LENGTH,
        dtype=resolve_dtype(C.MODEL_DTYPE),
        load_in_4bit=C.LOAD_IN_4BIT,
    )
    model.enable_input_require_grads()
    model_device = next(model.parameters()).device
    print(f"Model loaded: {C.MODEL_NAME}  VRAM: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    def setup_layer_freezing(target_model):
        layers = resolve_transformer_layers(target_model)
        total = len(layers)
        train_start, train_end = C.TRAIN_LAYER_RANGE if C.TRAIN_LAYER_RANGE else (0, total - 1)
        train_start = max(0, min(train_start, total - 1))
        train_end = max(train_start, min(train_end, total - 1))
        frozen = 0
        trainable = 0
        for name, param in target_model.named_parameters():
            should_train = True
            if "embed" in name or "lm_head" in name:
                should_train = False
            elif "layers." in name:
                layer_idx = int(name.split("layers.")[1].split(".")[0])
                if layer_idx < train_start or layer_idx > train_end:
                    should_train = False
            param.requires_grad = should_train
            if should_train:
                trainable += 1
            else:
                frozen += 1
        print(f"Freeze: train layers {train_start}-{train_end}/{total - 1}  frozen={frozen} trainable={trainable}")
        return {"total_layers": total, "train_start": train_start, "train_end": train_end}

    freezing_info = setup_layer_freezing(model)
    base_model_ref = None
    if C.PRESERVATION_KL_WEIGHT > 0:
        base_model_ref = copy.deepcopy(model)
        base_model_ref.eval()
        for param in base_model_ref.parameters():
            param.requires_grad = False

    class EntropyInjector(nn.Module):
        def __init__(self, hidden_dim, injector_hidden_dim, noise_scale):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(hidden_dim, injector_hidden_dim, bias=False),
                nn.GELU(),
                nn.Linear(injector_hidden_dim, hidden_dim, bias=False),
            )
            self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
            self.noise_scale = noise_scale
            nn.init.orthogonal_(self.net[0].weight, gain=math.sqrt(2))
            nn.init.zeros_(self.net[2].weight)

        def forward(self, hidden_states, assistant_mask=None):
            dtype = hidden_states.dtype
            perturb = self.norm(self.net(hidden_states))
            perturb = torch.nan_to_num(perturb * self.noise_scale, nan=0.0).clamp(-1, 1)
            if assistant_mask is not None:
                perturb = perturb * assistant_mask.to(dtype).unsqueeze(-1)
            return torch.nan_to_num(hidden_states + perturb, nan=0.0).to(dtype)

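    # InjectorManager: registers forward hooks on the chosen layers, caches the
    # pre- and post-injection hidden states, and turns those caches into the
    # orthogonality, stability, and decorrelation auxiliary losses used below.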
    class InjectorManager:
        def __init__(self, target_model, layers, hidden_dim, injector_hidden_dim, noise_scale, ortho_weight, ortho_on_pre, ortho_metric, device):
            self.model = target_model
            self.injector_layers = layers
            self.ortho_loss_weight = ortho_weight
            self.ortho_on_pre = ortho_on_pre
            self.ortho_metric = ortho_metric
            self.injectors = nn.ModuleDict()
            self.hooks = []
            self.cached_pre = {}
            self.cached_post = {}
            self.current_assistant_mask = None
            self.enabled_for_probe = False
            for layer_idx in layers:
                inj = EntropyInjector(hidden_dim, injector_hidden_dim, noise_scale).to(
                    device=device,
                    dtype=next(target_model.parameters()).dtype,
                )
                self.injectors[str(layer_idx)] = inj
            self._install()

        def _install(self):
            layers = resolve_transformer_layers(self.model)
            for layer_idx in self.injector_layers:
                self.hooks.append(layers[layer_idx].register_forward_hook(self._make_hook(layer_idx)))

        def _make_hook(self, layer_idx):
            injector = self.injectors[str(layer_idx)]

            def hook_fn(module, inputs, output):
                if not self.model.training and not self.enabled_for_probe:
                    return output
                if isinstance(output, tuple):
                    hidden = output[0]
                    self.cached_pre[layer_idx] = hidden
                    result = injector(hidden, self.current_assistant_mask)
                    if not torch.isfinite(result).all():
                        result = hidden
                    self.cached_post[layer_idx] = result
                    return (result,) + output[1:]
                self.cached_pre[layer_idx] = output
                result = injector(output, self.current_assistant_mask)
                if not torch.isfinite(result).all():
                    result = output
                self.cached_post[layer_idx] = result
                return result

            return hook_fn

        def set_noise_scale(self, noise_scale):
            for injector in self.injectors.values():
                injector.noise_scale = noise_scale

        def compute_ortho_loss(self, use_post=False):
            cache = self.cached_post if use_post else self.cached_pre
            if len(cache) < 2:
                return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
            loss = torch.zeros(1, device=next(self.injectors.parameters()).device)
            layers = sorted(cache.keys())
            for i in range(len(layers)):
                for j in range(i + 1, len(layers)):
                    a = cache[layers[i]]
                    b = cache[layers[j]]
                    if self.ortho_metric == "token_wise":
                        sim = F.cosine_similarity(a.view(-1, a.size(-1)), b.view(-1, b.size(-1)), dim=-1)
                        loss = loss + sim.square().mean()
                    else:
                        sim = F.cosine_similarity(a.mean((0, 1)).unsqueeze(0), b.mean((0, 1)).unsqueeze(0)).squeeze()
                        loss = loss + sim.square()
            return torch.nan_to_num(loss * self.ortho_loss_weight).clamp_max(10.0).squeeze()

        def snapshot_cache(self, use_post=False, detach=False):
            cache = self.cached_post if use_post else self.cached_pre
            return {k: v.detach() for k, v in cache.items()} if detach else dict(cache)

        def compute_stability_loss(self, clean, perturbed, mask=None):
            if not clean or not perturbed:
                return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
            losses = []
            for layer_idx in sorted(set(clean) & set(perturbed)):
                a = clean[layer_idx].to(perturbed[layer_idx].dtype)
                b = perturbed[layer_idx]
                if not torch.isfinite(a).all() or not torch.isfinite(b).all():
                    continue
                diff = (b - a).pow(2)
                if mask is not None:
                    m = mask.to(diff.dtype).unsqueeze(-1)
                    denom = m.sum() * diff.size(-1)
                    if denom.item() > 0:
                        losses.append((diff * m).sum() / denom)
                        continue
                losses.append(diff.mean())
            return torch.stack(losses).mean() if losses else torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()

        def compute_decorrelation_loss(self, mask=None, use_post=True):
            cache = self.cached_post if use_post else self.cached_pre
            layers = sorted(cache.keys())
            if len(layers) < 2:
                return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
            pooled = {}
            for layer_idx in layers:
                hidden = cache[layer_idx]
                if not torch.isfinite(hidden).all():
                    continue
                hidden_f = hidden.float()
                if mask is not None:
                    m = mask.to(hidden_f.dtype).unsqueeze(-1)
                    pooled[layer_idx] = (hidden_f * m).sum(1) / m.sum(1).clamp_min(1)
                else:
                    pooled[layer_idx] = hidden_f.mean(1)
            valid = sorted(pooled.keys())
            if len(valid) < 2:
                return torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()
            terms = []
            for i in range(len(valid)):
                for j in range(i + 1, len(valid)):
                    terms.append(F.cosine_similarity(pooled[valid[i]], pooled[valid[j]], dim=-1).square().mean())
            return torch.stack(terms).mean() if terms else torch.zeros(1, device=next(self.injectors.parameters()).device).squeeze()

        def clear_cache(self):
            self.cached_pre.clear()
            self.cached_post.clear()
            self.current_assistant_mask = None

        def get_trainable_params(self):
            return list(self.injectors.parameters())

        def remove_hooks(self):
            for hook in self.hooks:
                hook.remove()
            self.hooks.clear()

    injector_mgr = None
    if C.INJECTOR_ENABLED and C.INJECTOR_LAYERS:
        injector_mgr = InjectorManager(
            model,
            C.INJECTOR_LAYERS,
            model.config.hidden_size,
            C.INJECTOR_HIDDEN_DIM,
            C.INJECTOR_NOISE_SCALE,
            C.ORTHO_LOSS_WEIGHT,
            C.ORTHO_ON_PRE_INJECTION,
            C.ORTHO_METRIC,
            model_device,
        )

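    # SlicedLinear: runs a Linear restricted to the surviving output rows
    # (alive_out) and/or input columns (alive_in). With compact=True the result
    # stays in the reduced width; otherwise it is scattered back into a
    # zero-padded full-width tensor.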
    class SlicedLinear:
        @staticmethod
        def forward(linear, x, alive_out=None, alive_in=None, compact=False):
            weight, bias = linear.weight, linear.bias
            full_out = weight.shape[0]
            batch_shape = x.shape[:-1]
            if alive_out is None and alive_in is None:
                return F.linear(x, weight, bias)
            w = weight
            b = bias
            x_in = x
            if alive_out is not None:
                w = w[alive_out]
                b = b[alive_out] if b is not None else None
            if alive_in is not None:
                w = w[:, alive_in]
                x_in = x[..., alive_in]
            partial = F.linear(x_in, w, b)
            if compact or alive_out is None:
                return partial
            if alive_out.shape[0] < full_out:
                out = x.new_zeros(*batch_shape, full_out)
                out[..., alive_out] = partial
                return out
            return partial

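    # LayerTopology: per-layer record of this shuffle's survivors -- which
    # attention heads, o_proj rows, MLP channels, and hidden groups stay alive.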
    @dataclass
    class LayerTopology:
        mode: str = "both"
        alive_q_heads: list | None = None
        alive_kv_heads: list | None = None
        alive_q_out: torch.Tensor | None = None
        alive_k_out: torch.Tensor | None = None
        alive_v_out: torch.Tensor | None = None
        alive_o_out: torch.Tensor | None = None
        n_alive_q_heads: int = 0
        n_alive_kv_heads: int = 0
        alive_gate_out: torch.Tensor | None = None
        alive_up_out: torch.Tensor | None = None
        alive_down_in: torch.Tensor | None = None
        alive_down_out: torch.Tensor | None = None

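    # DeepChaosScheduler: every sticky_interval steps it re-samples, per victim
    # layer, whether the layer runs at all and which heads/channels survive,
    # then swaps in a custom forward that computes only the surviving slices.
    # Streak limits (max_consecutive_on/off) keep any layer from being stuck
    # always on or always off. Note: the sliced forward reads .weight directly,
    # so it assumes full-precision (non-4-bit) weights.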
    class DeepChaosScheduler:
        def __init__(
            self,
            target_model,
            sacred_layers,
            victim_range,
            min_layer_survival,
            max_layer_survival,
            min_head_survival,
            max_head_survival,
            channel_group_size,
            min_channel_survival,
            max_channel_survival,
            mlp_gate_group_size,
            min_mlp_gate_survival,
            max_mlp_gate_survival,
            hidden_group_size,
            min_hidden_survival,
            max_hidden_survival,
            max_consecutive_on,
            max_consecutive_off,
            sticky_interval,
            seed,
        ):
            self.model = target_model
            self.sacred = set(sacred_layers)
            self.victims = list(range(victim_range[0], victim_range[1]))
            self.min_ls = min_layer_survival
            self.max_ls = max_layer_survival
            self.min_hs = min_head_survival
            self.max_hs = max_head_survival
            self.cgs = channel_group_size
            self.min_cs = min_channel_survival
            self.max_cs = max_channel_survival
            self.mgg = mlp_gate_group_size
            self.min_mgs = min_mlp_gate_survival
            self.max_mgs = max_mlp_gate_survival
            self.hgg = hidden_group_size
            self.min_hds = min_hidden_survival
            self.max_hds = max_hidden_survival
            self.max_on = max_consecutive_on
            self.max_off = max_consecutive_off
            self.sticky_interval = max(1, int(sticky_interval))
            self.seed = seed
            self.on_streak = {layer: 0 for layer in self.victims}
            self.off_streak = {layer: 0 for layer in self.victims}
            self.topologies = {}
            self.active_layers = set(self.sacred)
            self.last_shuffle_step = None
            self.cached_stats = None
            layers = resolve_transformer_layers(target_model)
            sample = layers[self.victims[0]]
            self.hs = sample.self_attn.q_proj.weight.shape[1]
            self.q_dim = sample.self_attn.q_proj.weight.shape[0]
            self.k_dim = sample.self_attn.k_proj.weight.shape[0]
            self.v_dim = sample.self_attn.v_proj.weight.shape[0]
            self.inter = sample.mlp.gate_proj.weight.shape[0]
            cfg = getattr(target_model, "config", None)
            self.nh = getattr(cfg, "num_attention_heads", None)
            self.nkv = getattr(cfg, "num_key_value_heads", None)
            self.hd = getattr(cfg, "head_dim", None)
            if self.nh is None:
                self.nh = getattr(sample.self_attn, "num_heads", None)
            if self.nkv is None:
                self.nkv = getattr(sample.self_attn, "num_key_value_heads", None)
            if self.hd is None:
                self.hd = getattr(sample.self_attn, "head_dim", None)
            if self.nh is None:
                raise AttributeError("Could not determine num_attention_heads from model config or attention module")
            if self.nkv is None:
                self.nkv = self.nh
            if self.hd is None:
                self.hd = self.hs // self.nh
            self.gqa_r = self.nh // self.nkv
            self.n_cg = max(1, self.inter // self.cgs)
            self.a_cg = self.inter // self.n_cg
            self.n_hg = max(1, self.hs // self.hgg)
            self.a_hg = self.hs // self.n_hg
            self.original_forwards = {}
            self._install()

        def _install(self):
            layers = resolve_transformer_layers(self.model)
            for idx in self.victims:
                self.original_forwards[idx] = layers[idx].forward
                layers[idx].forward = self._make_forward(idx, layers[idx])
            print(
                f"DeepChaosScheduler v3: victims={self.victims[0]}-{self.victims[-1]} sacred={sorted(self.sacred)}"
            )
            print(f"  heads={self.nh} kv={self.nkv} head_dim={self.hd} hidden={self.hs} inter={self.inter}")

        def _sample_groups(self, rng, num_groups, group_size, total_dim, min_rate, max_rate, device):
            rate = rng.uniform(min_rate, max_rate)
            num_alive = max(1, int(round(num_groups * rate)))
            alive_groups = sorted(rng.sample(range(num_groups), num_alive))
            indices = []
            for group in alive_groups:
                start = group * group_size
                indices.extend(range(start, min(start + group_size, total_dim)))
            return torch.tensor(indices, dtype=torch.long, device=device)

        def _heads_to_indices(self, heads, head_dim, device):
            indices = []
            for head in heads:
                indices.extend(range(head * head_dim, (head + 1) * head_dim))
            return torch.tensor(indices, dtype=torch.long, device=device)
        
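        # _apply_rotary: normalizes the cos/sin tensors (dtype, trailing head
        # dim, rank, and sequence length), which arrive in slightly different
        # shapes depending on the transformers version, before calling
        # apply_rotary_pos_emb on the compact q/k heads.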
        def _apply_rotary(self, q, k, cos, sin, position_ids=None):
            hd = self.hd
            seq_len = q.shape[2]

            def fix_rope(t):
                t = t.float()
                
                if t.shape[-1] != hd:
                    t = t[..., :hd]
                
                while t.dim() > 2:
                    t = t.squeeze(0)

                if t.shape[0] > seq_len:
                    if position_ids is not None:
                        t = t[position_ids]
                    else:
                        t = t[:seq_len]

                while t.dim() < 3:
                    t = t.unsqueeze(0)

                return t
            
            cos = fix_rope(cos).to(q.dtype)
            sin = fix_rope(sin).to(q.dtype)
            return apply_rotary_pos_emb(q, k, cos, sin)

        def _subsample(self, base, group_size, min_rate, max_rate, rng):
            n = len(base)
            num_groups = max(1, n // group_size)
            if num_groups <= 1:
                return base
            num_alive = max(1, int(round(num_groups * rng.uniform(min_rate, max_rate))))
            alive_groups = sorted(rng.sample(range(num_groups), num_alive))
            idx = []
            for group in alive_groups:
                start = group * group_size
                idx.extend(range(start, min(start + group_size, n)))
            return base[idx] if isinstance(base, torch.Tensor) else [base[i] for i in idx]

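        # step: re-samples the active-layer set and per-layer topology at most
        # once per sticky_interval steps (cached_stats is reused in between).
        # Layers past their off-streak limit are forced on; layers past their
        # on-streak limit are forced off; each active layer then draws a mode
        # (both / attn-only / mlp-only / identity) and its survivor indices.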
        def step(self, global_step):
            if (
                self.cached_stats is not None
                and self.last_shuffle_step is not None
                and global_step < self.last_shuffle_step + self.sticky_interval
            ):
                return self.cached_stats

            self.last_shuffle_step = global_step
            block_idx = global_step // self.sticky_interval
            rng = random.Random(self.seed + global_step)
            device = next(self.model.parameters()).device
            num_victims = len(self.victims)
            min_active = max(1, int(round(num_victims * self.min_ls)))
            max_active = max(min_active, int(round(num_victims * self.max_ls)))
            target = rng.randint(min_active, max_active)
            forced_on = [layer for layer in self.victims if self.off_streak[layer] >= self.max_off]
            forced_off = [layer for layer in self.victims if self.on_streak[layer] >= self.max_on]
            available = [layer for layer in self.victims if layer not in forced_on and layer not in forced_off]
            active = set(forced_on)
            need = target - len(active)
            if need > 0 and available:
                active.update(rng.sample(available, min(need, len(available))))
            for layer in self.victims:
                if layer in active:
                    self.on_streak[layer] += 1
                    self.off_streak[layer] = 0
                else:
                    self.on_streak[layer] = 0
                    self.off_streak[layer] += 1
            self.active_layers = active | self.sacred

            modes = {}
            for layer in self.victims:
                if layer not in self.active_layers:
                    modes[layer] = "dead"
                    continue
                draw = rng.random()
                if draw < 0.30:
                    modes[layer] = "both"
                elif draw < 0.55:
                    modes[layer] = "attn"
                elif draw < 0.80:
                    modes[layer] = "mlp"
                else:
                    modes[layer] = "identity"

            self.topologies.clear()
            sliced_elements = 0
            full_elements = 0
            stats = {key: [] for key in ("q", "k", "v", "o", "gate", "up", "down")}

            for layer in self.victims:
                topo = LayerTopology(mode=modes[layer])
                if topo.mode in ("dead", "identity"):
                    self.topologies[layer] = topo
                    continue

                if topo.mode in ("both", "attn"):
                    kv_rate = rng.uniform(self.min_hs, self.max_hs)
                    n_alive_kv = max(1, int(round(self.nkv * kv_rate)))
                    alive_kv = sorted(rng.sample(range(self.nkv), n_alive_kv))
                    topo.alive_kv_heads = alive_kv
                    topo.n_alive_kv_heads = len(alive_kv)
                    alive_q = []
                    for kv_head in alive_kv:
                        for q_head in range(kv_head * self.gqa_r, (kv_head + 1) * self.gqa_r):
                            if q_head < self.nh:
                                alive_q.append(q_head)
                    topo.alive_q_heads = alive_q
                    topo.n_alive_q_heads = len(alive_q)
                    topo.alive_q_out = self._heads_to_indices(alive_q, self.hd, device)
                    topo.alive_k_out = self._heads_to_indices(alive_kv, self.hd, device)
                    topo.alive_v_out = self._heads_to_indices(alive_kv, self.hd, device)
                    topo.alive_o_out = self._sample_groups(
                        rng, self.n_hg, self.a_hg, self.hs, self.min_hds, self.max_hds, device
                    )
                    stats["q"].append(len(topo.alive_q_out) / self.q_dim)
                    stats["k"].append(len(topo.alive_k_out) / self.k_dim)
                    stats["v"].append(len(topo.alive_v_out) / self.v_dim)
                    stats["o"].append(len(topo.alive_o_out) / self.hs)
                    sliced_elements += (len(topo.alive_q_out) + len(topo.alive_k_out) + len(topo.alive_v_out)) * self.hs
                    sliced_elements += len(topo.alive_o_out) * (topo.n_alive_q_heads * self.hd)
                    full_elements += (self.q_dim + self.k_dim + self.v_dim + self.hs) * self.hs

                if topo.mode in ("both", "mlp"):
                    channels = self._sample_groups(
                        rng, self.n_cg, self.a_cg, self.inter, self.min_cs, self.max_cs, device
                    )
                    core_indices = self._subsample(channels, self.mgg, self.min_mgs, self.max_mgs, rng)
                    topo.alive_gate_out = core_indices
                    topo.alive_up_out = core_indices
                    topo.alive_down_in = core_indices
                    topo.alive_down_out = self._sample_groups(
                        rng, self.n_hg, self.a_hg, self.hs, self.min_hds, self.max_hds, device
                    )
                    stats["gate"].append(len(topo.alive_gate_out) / self.inter)
                    stats["up"].append(len(topo.alive_up_out) / self.inter)
                    stats["down"].append(len(topo.alive_down_out) / self.hs)
                    sliced_elements += (len(topo.alive_gate_out) + len(topo.alive_up_out)) * self.hs
                    sliced_elements += len(topo.alive_down_out) * len(topo.alive_down_in)
                    full_elements += 3 * self.inter * self.hs

                self.topologies[layer] = topo

            active_count = sum(1 for layer in self.victims if layer in self.active_layers)
            mode_counts = {"both": 0, "attn": 0, "mlp": 0, "identity": 0, "dead": 0}
            for layer in self.victims:
                mode_counts[modes[layer]] += 1
            avg = lambda values: sum(values) / len(values) if values else 0.0
            compute_ratio = sliced_elements / max(1, full_elements)
            self.cached_stats = {
                "shuffle_step": global_step,
                "shuffle_block": block_idx,
                "active_layers": active_count,
                "layer_density_pct": 100 * active_count / max(1, num_victims),
                "mode_both": mode_counts["both"],
                "mode_attn_only": mode_counts["attn"],
                "mode_mlp_only": mode_counts["mlp"],
                "mode_identity": mode_counts["identity"],
                "mode_dead": mode_counts["dead"],
                "avg_q_surv": 100 * avg(stats["q"]),
                "avg_k_surv": 100 * avg(stats["k"]),
                "avg_v_surv": 100 * avg(stats["v"]),
                "avg_o_surv": 100 * avg(stats["o"]),
                "avg_gate_surv": 100 * avg(stats["gate"]),
                "avg_up_surv": 100 * avg(stats["up"]),
                "avg_down_surv": 100 * avg(stats["down"]),
                "compute_ratio": compute_ratio,
                "compute_pct": 100 * compute_ratio,
            }
            print(
                "Deep Chaos Shuffle "
                f"@ step {global_step} block={block_idx} "
                f"compute={self.cached_stats['compute_pct']:.1f}% "
                f"active={active_count} "
                f"both={mode_counts['both']} attn={mode_counts['attn']} "
                f"mlp={mode_counts['mlp']} identity={mode_counts['identity']} dead={mode_counts['dead']}"
            )
            return self.cached_stats

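        # _make_forward: the replacement decoder-layer forward. At eval time it
        # delegates to the original dense forward; during training it computes
        # attention and MLP only over the surviving heads/channels, scattering
        # compact outputs back into full hidden width before the residual add.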
        def _make_forward(self, layer_idx, layer):
            scheduler = self

            def forward(
                hidden_states,
                attention_mask=None,
                position_ids=None,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                cache_position=None,
                position_embeddings=None,
                **kwargs,
            ):
                if not scheduler.model.training:
                    return scheduler.original_forwards[layer_idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_value,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                        **kwargs,
                    )

                topo = scheduler.topologies.get(layer_idx)
                if topo is None or topo.mode in ("dead", "identity"):
                    return (hidden_states, None, past_key_value)

                residual = hidden_states
                batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]

                if topo.mode in ("both", "attn"):
                    normed = layer.input_layernorm(hidden_states)
                    q = SlicedLinear.forward(layer.self_attn.q_proj, normed, alive_out=topo.alive_q_out, compact=True)
                    k = SlicedLinear.forward(layer.self_attn.k_proj, normed, alive_out=topo.alive_k_out, compact=True)
                    v = SlicedLinear.forward(layer.self_attn.v_proj, normed, alive_out=topo.alive_v_out, compact=True)
                    n_alive_q = len(topo.alive_q_heads or [])
                    n_alive_kv = len(topo.alive_kv_heads or [])
                    expected_q_width = n_alive_q * scheduler.hd
                    expected_kv_width = n_alive_kv * scheduler.hd
                    if q.shape[-1] != expected_q_width:
                        raise RuntimeError(
                            f"Layer {layer_idx}: q width mismatch {q.shape[-1]} != {expected_q_width}"
                        )
                    if k.shape[-1] != expected_kv_width:
                        raise RuntimeError(
                            f"Layer {layer_idx}: k width mismatch {k.shape[-1]} != {expected_kv_width}"
                        )
                    if v.shape[-1] != expected_kv_width:
                        raise RuntimeError(
                            f"Layer {layer_idx}: v width mismatch {v.shape[-1]} != {expected_kv_width}"
                        )

                    q_4d = q.view(batch_size, seq_len, n_alive_q, scheduler.hd).transpose(1, 2)
                    k_4d = k.view(batch_size, seq_len, n_alive_kv, scheduler.hd).transpose(1, 2)
                    v_4d = v.view(batch_size, seq_len, n_alive_kv, scheduler.hd).transpose(1, 2)

                    if scheduler.gqa_r > 1:
                        k_4d = k_4d.repeat_interleave(scheduler.gqa_r, dim=1)
                        v_4d = v_4d.repeat_interleave(scheduler.gqa_r, dim=1)
                    if k_4d.shape[1] != q_4d.shape[1] or v_4d.shape[1] != q_4d.shape[1]:
                        raise RuntimeError(
                            f"Layer {layer_idx}: GQA head mismatch "
                            f"q={q_4d.shape[1]} k={k_4d.shape[1]} v={v_4d.shape[1]}"
                        )
                    if position_embeddings is not None:
                        cos, sin = position_embeddings
                        q_4d, k_4d = scheduler._apply_rotary(q_4d, k_4d, cos, sin, position_ids)

                    attn_mask_4d = None
                    if attention_mask is not None and attention_mask.dim() == 4:
                        if attention_mask.shape[1] in (1, n_alive_q):
                            attn_mask_4d = attention_mask
                        elif attention_mask.shape[1] == scheduler.nh:
                            raise RuntimeError(
                                "Head-specific 4D attention masks are not supported with compact deep-chaos heads"
                            )
                        else:
                            attn_mask_4d = attention_mask
                    elif attention_mask is not None and attention_mask.dim() == 2:
                        attn_mask_4d = attention_mask[:, None, None, :].to(dtype=q_4d.dtype)
                        attn_mask_4d = (1.0 - attn_mask_4d) * torch.finfo(q_4d.dtype).min

                    attn_out = F.scaled_dot_product_attention(
                        q_4d,
                        k_4d,
                        v_4d,
                        attn_mask=attn_mask_4d,
                        dropout_p=0.0,
                        is_causal=(attn_mask_4d is None),
                    )
                    if attn_out.shape[1] != n_alive_q or attn_out.shape[3] != scheduler.hd:
                        raise RuntimeError(
                            f"Layer {layer_idx}: attention output shape mismatch {tuple(attn_out.shape)}"
                        )
                    attn_compact = attn_out.transpose(1, 2).contiguous().reshape(batch_size, seq_len, -1)
                    alive_o_in = scheduler._heads_to_indices(topo.alive_q_heads, scheduler.hd, hidden_states.device)
                    if attn_compact.shape[-1] != alive_o_in.numel():
                        raise RuntimeError(
                            f"Layer {layer_idx}: compact attention width mismatch "
                            f"{attn_compact.shape[-1]} != {alive_o_in.numel()}"
                        )
                    o_weight = layer.self_attn.o_proj.weight[topo.alive_o_out][:, alive_o_in]
                    o_bias = (
                        layer.self_attn.o_proj.bias[topo.alive_o_out]
                        if layer.self_attn.o_proj.bias is not None
                        else None
                    )
                    attn_compact_out = F.linear(attn_compact, o_weight, o_bias)
                    attn_hidden = hidden_states.new_zeros(batch_size, seq_len, scheduler.hs)
                    attn_hidden[..., topo.alive_o_out] = attn_compact_out
                    hidden_states = residual + attn_hidden
                else:
                    hidden_states = residual

                residual = hidden_states
                if topo.mode in ("both", "mlp"):
                    normed = layer.post_attention_layernorm(hidden_states)
                    gate = SlicedLinear.forward(layer.mlp.gate_proj, normed, alive_out=topo.alive_gate_out, compact=True)
                    up = SlicedLinear.forward(layer.mlp.up_proj, normed, alive_out=topo.alive_up_out, compact=True)
                    if gate.shape[-1] != topo.alive_down_in.numel() or up.shape[-1] != topo.alive_down_in.numel():
                        raise RuntimeError(
                            f"Layer {layer_idx}: SwiGLU core mismatch "
                            f"gate={gate.shape[-1]} up={up.shape[-1]} core={topo.alive_down_in.numel()}"
                        )
                    activated = F.silu(gate) * up

                    weight = layer.mlp.down_proj.weight[topo.alive_down_out][:, topo.alive_down_in]
                    bias = (
                        layer.mlp.down_proj.bias[topo.alive_down_out]
                        if layer.mlp.down_proj.bias is not None
                        else None
                    )
                    mlp_compact = F.linear(activated, weight, bias)
                    mlp_hidden = hidden_states.new_zeros(batch_size, seq_len, scheduler.hs)
                    mlp_hidden[..., topo.alive_down_out] = mlp_compact
                    hidden_states = residual + mlp_hidden
                else:
                    hidden_states = residual

                return (hidden_states, None, past_key_value)

            return forward

        def freeze_topology(self, step):
            self.step(step)

        def remove(self):
            layers = resolve_transformer_layers(self.model)
            for idx, original in self.original_forwards.items():
                layers[idx].forward = original
            self.original_forwards.clear()
            self.topologies.clear()
            self.cached_stats = None
            self.last_shuffle_step = None
            print("DeepChaosScheduler removed")

    deep_chaos = None
    if C.DEEP_CHAOS_ENABLED:
        deep_chaos = DeepChaosScheduler(
            model,
            C.DEEP_CHAOS_SACRED_LAYERS,
            C.DEEP_CHAOS_VICTIM_RANGE,
            C.DEEP_CHAOS_MIN_LAYER_SURVIVAL,
            C.DEEP_CHAOS_MAX_LAYER_SURVIVAL,
            C.DEEP_CHAOS_MIN_HEAD_SURVIVAL,
            C.DEEP_CHAOS_MAX_HEAD_SURVIVAL,
            C.DEEP_CHAOS_CHANNEL_GROUP_SIZE,
            C.DEEP_CHAOS_MIN_CHANNEL_SURVIVAL,
            C.DEEP_CHAOS_MAX_CHANNEL_SURVIVAL,
            C.DEEP_CHAOS_MLP_GATE_GROUP_SIZE,
            C.DEEP_CHAOS_MIN_MLP_GATE_SURVIVAL,
            C.DEEP_CHAOS_MAX_MLP_GATE_SURVIVAL,
            C.DEEP_CHAOS_HIDDEN_GROUP_SIZE,
            C.DEEP_CHAOS_MIN_HIDDEN_SURVIVAL,
            C.DEEP_CHAOS_MAX_HIDDEN_SURVIVAL,
            C.DEEP_CHAOS_MAX_CONSECUTIVE_ON,
            C.DEEP_CHAOS_MAX_CONSECUTIVE_OFF,
            C.DEEP_CHAOS_STICKY_INTERVAL,
            C.DEEP_CHAOS_SEED,
        )

    # Fill these in as you see fit; these were for my bartender personality model
    # ("related" = persona/voice tokens, "unrelated" = generic assistant-speak).
    PROBE_TOKENS = {
        "related": ["gritty", "casual", "weird", "bartender", "conversation", "direct", "physical", "grounded"],
        "unrelated": ["certainly", "apologize", "assist", "helpful", "serving", "comply", "obedient", "dutifully"],
    }

    def get_token_ids(tok, words):
        ids = []
        for word in words:
            encoded = tok.encode(word, add_special_tokens=False)
            if encoded:
                ids.append(encoded[0])
        return ids

    PROBE_RELATED_IDS = get_token_ids(tokenizer, PROBE_TOKENS["related"])
    PROBE_UNRELATED_IDS = get_token_ids(tokenizer, PROBE_TOKENS["unrelated"])
    PROBE_ALL_IDS = PROBE_RELATED_IDS + PROBE_UNRELATED_IDS

    def _capture_probe_state():
        if injector_mgr is None:
            return None
        noise_scale = None
        if injector_mgr.injectors:
            noise_scale = next(iter(injector_mgr.injectors.values())).noise_scale
        return {
            "noise_scale": noise_scale,
            "assistant_mask": injector_mgr.current_assistant_mask,
        }

    def _disable_injectors_for_probe():
        if injector_mgr is None:
            return
        injector_mgr.set_noise_scale(0.0)
        injector_mgr.current_assistant_mask = None
        injector_mgr.clear_cache()

    def _restore_probe_state(state):
        if injector_mgr is None or state is None:
            return
        if state["noise_scale"] is not None:
            injector_mgr.set_noise_scale(state["noise_scale"])
        injector_mgr.current_assistant_mask = state["assistant_mask"]

    def extract_layer_embeddings(target_model, tok, token_ids, layer_range, use_chaos=False):
        activations = {}
        hooks = []
        layers = resolve_transformer_layers(target_model)
        was_training = target_model.training
        probe_state = _capture_probe_state()
        if use_chaos:
            target_model.train()
        else:
            target_model.eval()
        _disable_injectors_for_probe()

        def make_hook(layer_idx):
            def hook_fn(module, inputs, output):
                hidden = output[0] if isinstance(output, tuple) else output
                activations[layer_idx] = hidden.detach().float().cpu().numpy()
            return hook_fn

        for layer_idx in range(layer_range[0], layer_range[1] + 1):
            hooks.append(layers[layer_idx].register_forward_hook(make_hook(layer_idx)))

        input_ids = torch.tensor([token_ids], device=model_device)
        with torch.no_grad():
            target_model(input_ids, return_dict=True)

        for hook in hooks:
            hook.remove()
        _restore_probe_state(probe_state)
        if was_training:
            target_model.train()
        else:
            target_model.eval()
        return activations

    def compute_sim_matrix(embeddings):
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)
        normed = embeddings / norms
        return normed @ normed.T

    def compute_pair_sim(a, b):
        a_flat = a.reshape(-1).astype(np.float32)
        b_flat = b.reshape(-1).astype(np.float32)
        a_norm = np.linalg.norm(a_flat)
        b_norm = np.linalg.norm(b_flat)
        return float(np.dot(a_flat, b_flat) / (a_norm * b_norm)) if a_norm > 0 and b_norm > 0 else 0.0

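    # run_bol_probe: embeds the "related" (persona) and "unrelated"
    # (assistant-speak) probe tokens, then reports per-layer cosine-similarity
    # stats -- within-group means, their separation, and cross-layer redundancy
    # for the injector layers and the last few layers.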
    def run_bol_probe(target_model, tok, step, layer_range=C.PROBE_LAYER_RANGE, use_chaos=False):
        related = get_token_ids(tok, PROBE_TOKENS["related"])
        unrelated = get_token_ids(tok, PROBE_TOKENS["unrelated"])
        all_ids = related + unrelated
        n_related = len(related)
        embeddings = extract_layer_embeddings(target_model, tok, all_ids, layer_range, use_chaos=use_chaos)
        results = {"step": step, "layers": {}, "pairs": {}, "mode": "chaos" if use_chaos else "dense"}
        for layer_idx in sorted(embeddings.keys()):
            emb = embeddings[layer_idx][0]
            sim = compute_sim_matrix(emb)
            related_block = sim[:n_related, :n_related]
            unrelated_block = sim[n_related:, n_related:]
            mask_r = ~np.eye(n_related, dtype=bool)
            mask_u = ~np.eye(len(unrelated), dtype=bool)
            mean_r = related_block[mask_r].mean() if mask_r.sum() > 0 else 0.0
            mean_u = unrelated_block[mask_u].mean() if mask_u.sum() > 0 else 0.0
            results["layers"][layer_idx] = {
                "identity_sim": float(mean_r),
                "deference_sim": float(mean_u),
                "separation": float(mean_r - mean_u),
                "activation_variance": float(np.var(emb)),
            }
        injector_layers = [layer for layer in C.INJECTOR_LAYERS if layer in embeddings]
        late_layers = [layer for layer in range(max(layer_range[1] - 2, layer_range[0]), layer_range[1] + 1) if layer in embeddings]
        for group_name, group_layers in {"injector": injector_layers, "late": late_layers}.items():
            for i, layer_a in enumerate(group_layers):
                for layer_b in group_layers[i + 1 :]:
                    results["pairs"][f"{group_name}_{layer_a}_vs_{layer_b}"] = compute_pair_sim(
                        embeddings[layer_a][0],
                        embeddings[layer_b][0],
                    )
        injector_vals = [value for key, value in results["pairs"].items() if key.startswith("injector_")]
        late_vals = [value for key, value in results["pairs"].items() if key.startswith("late_")]
        results["summary"] = {
            "mean_separation": float(np.mean([v["separation"] for v in results["layers"].values()])),
            "mean_injector_redundancy": float(np.mean(injector_vals)) if injector_vals else 0.0,
            "mean_late_redundancy": float(np.mean(late_vals)) if late_vals else 0.0,
        }
        return results

    def print_probe_report(results):
        print(f"\n{'=' * 65}")
        print(f"BoL PROBE [{results.get('mode', 'dense').upper()}] @ step {results['step']}")
        print(f"{'=' * 65}")
        print(f"  Mean Separation: {results['summary']['mean_separation']:.4f}")
        print(f"  Injector Redundancy: {results['summary']['mean_injector_redundancy']:.4f}")
        print(f"  Late Redundancy: {results['summary']['mean_late_redundancy']:.4f}")
        print(f"{'_' * 65}")
        for layer_idx in sorted(results["layers"].keys()):
            values = results["layers"][layer_idx]
            inj = " <-INJ" if layer_idx in C.INJECTOR_LAYERS else ""
            print(
                f"  {layer_idx:<4} id={values['identity_sim']:.4f} def={values['deference_sim']:.4f} "
                f"sep={values['separation']:.4f} var={values['activation_variance']:.4f}{inj}"
            )
        print(f"{'=' * 65}\n")

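    # ProbeCallback: every probe_every steps, runs the BoL probe in dense mode
    # and (if deep chaos is on) with a frozen chaos topology, writes the JSON
    # to the volume, and logs the summary numbers to W&B.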
    class ProbeCallback(TrainerCallback):
        def __init__(self, tok, probe_every, deep_chaos_ref=None):
            self.tokenizer = tok
            self.probe_every = probe_every
            self.deep_chaos = deep_chaos_ref

        def on_step_end(self, args, state, control, model=None, **kwargs):
            if not C.PROBE_ENABLED or state.global_step == 0 or state.global_step % self.probe_every != 0:
                return control
            dense = run_bol_probe(model, self.tokenizer, state.global_step, use_chaos=False)
            print_probe_report(dense)
            chaos_results = None
            if self.deep_chaos is not None:
                self.deep_chaos.freeze_topology(state.global_step)
                chaos_results = run_bol_probe(model, self.tokenizer, state.global_step, use_chaos=True)
                print_probe_report(chaos_results)
            payload = {"dense": dense}
            if chaos_results is not None:
                payload["chaos"] = chaos_results
            Path(C.OUTPUT_DIR, f"probe_step_{state.global_step}.json").write_text(json.dumps(payload, indent=2))
            volume.commit()
            if wandb.run:
                wandb.log(
                    {
                        "probe_dense/separation": dense["summary"]["mean_separation"],
                        "probe_dense/late_redundancy": dense["summary"]["mean_late_redundancy"],
                    },
                    step=state.global_step,
                )
                if chaos_results is not None:
                    wandb.log(
                        {
                            "probe_chaos/separation": chaos_results["summary"]["mean_separation"],
                            "probe_chaos/late_redundancy": chaos_results["summary"]["mean_late_redundancy"],
                        },
                        step=state.global_step,
                    )
            return control

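    # BoundaryProbeCallback: measures how much next-token probability mass the
    # model puts on role-boundary tokens (user/assistant markers) after each
    # prompt -- a cheap check that training isn't eroding turn boundaries.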
    class BoundaryProbeCallback(TrainerCallback):
        def __init__(self, tok, probe_every, deep_chaos_ref=None):
            self.tokenizer = tok
            self.probe_every = probe_every
            self.deep_chaos = deep_chaos_ref
            self.prompts = [
                # intentionally left blank -- add your own boundary-test prompts
            ]
            self.forbidden_ids = set()
            for text in ["<|im_start|>user", "<|im_start|>assistant", "user", "assistant", "User", "Assistant"]:
                self.forbidden_ids.update(self.tokenizer.encode(text, add_special_tokens=False))

        def _run_boundary(self, target_model):
            max_forbidden = 0.0
            for prompt in self.prompts:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(model_device)
                with torch.no_grad():
                    probs = torch.softmax(target_model(**inputs).logits[0, -1, :], dim=-1)
                max_forbidden = max(
                    max_forbidden,
                    sum(probs[token_id].item() for token_id in self.forbidden_ids if token_id < probs.shape[0]),
                )
            return max_forbidden

        def on_step_end(self, args, state, control, model=None, **kwargs):
            if model is None or state.global_step == 0 or state.global_step % self.probe_every != 0:
                return control
            was_training = model.training
            probe_state = _capture_probe_state()
            try:
                model.eval()
                _disable_injectors_for_probe()
                dense_forbidden = self._run_boundary(model)
                print(f"Boundary [DENSE] @ {state.global_step}: max_forbidden={dense_forbidden:.6f}")
                chaos_forbidden = 0.0
                if self.deep_chaos is not None:
                    model.train()
                    self.deep_chaos.freeze_topology(state.global_step)
                    _disable_injectors_for_probe()
                    chaos_forbidden = self._run_boundary(model)
                    print(f"Boundary [CHAOS] @ {state.global_step}: max_forbidden={chaos_forbidden:.6f}")
                if wandb.run:
                    wandb.log(
                        {
                            "boundary/dense_max_forbidden": dense_forbidden,
                            "boundary/chaos_max_forbidden": chaos_forbidden,
                        },
                        step=state.global_step,
                    )
            finally:
                _restore_probe_state(probe_state)
                if was_training:
                    model.train()
                else:
                    model.eval()
            return control

    if not getattr(tokenizer, "chat_template", None):
        tokenizer = get_chat_template(
            tokenizer,
            chat_template="llama-3",
            mapping={"role": "role", "content": "content", "user": "user", "assistant": "assistant"},
        )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    dataset = load_dataset(C.DATASET_NAME, split=C.DATASET_SPLIT).shuffle(seed=42)
    print(f"Dataset: {C.DATASET_NAME} ({len(dataset)} samples)")

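    # Assistant-only supervision: labels start as all -100, then each assistant
    # message is un-masked by tokenizing the conversation up to (prefix) and
    # through (prefix + message) that turn and copying input_ids over that
    # span. format_batch assumes the standardized dataset exposes a "messages"
    # column of role/content dicts.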
    def tokenize_conversation(convo):
        return tokenizer.apply_chat_template(
            convo,
            tokenize=True,
            add_generation_prompt=False,
            return_dict=True,
            truncation=True,
            max_length=C.MAX_SEQ_LENGTH,
        )

    def build_labels(convo, input_ids):
        labels = [-100] * len(input_ids)
        for idx, message in enumerate(convo):
            if message.get("role") != "assistant":
                continue
            prefix_ids = tokenize_conversation(convo[:idx])["input_ids"] if idx > 0 else []
            through_ids = tokenize_conversation(convo[: idx + 1])["input_ids"]
            for pos in range(min(len(prefix_ids), len(input_ids)), min(len(through_ids), len(input_ids))):
                labels[pos] = input_ids[pos]
        return labels

    def format_batch(examples):
        input_ids_list = []
        attention_masks_list = []
        labels_list = []
        for convo in examples["messages"]:
            tokenized = tokenize_conversation(convo)
            input_ids_list.append(tokenized["input_ids"])
            attention_masks_list.append(tokenized["attention_mask"])
            labels_list.append(build_labels(convo, tokenized["input_ids"]))
        return {
            "input_ids": input_ids_list,
            "attention_mask": attention_masks_list,
            "labels": labels_list,
        }

    dataset_formatted = standardize_sharegpt(dataset)
    dataset_formatted = dataset_formatted.map(format_batch, batched=True, remove_columns=dataset_formatted.column_names)

    os.makedirs(C.OUTPUT_DIR, exist_ok=True)
    if C.PROBE_ENABLED:
        print("Baseline probe (dense)...")
        baseline_dense = run_bol_probe(model, tokenizer, 0, use_chaos=False)
        print_probe_report(baseline_dense)
        baseline_chaos = None
        if deep_chaos is not None:
            print("Baseline probe (chaos)...")
            deep_chaos.freeze_topology(0)
            baseline_chaos = run_bol_probe(model, tokenizer, 0, use_chaos=True)
            print_probe_report(baseline_chaos)
        Path(C.OUTPUT_DIR, "probe_baseline.json").write_text(
            json.dumps({"dense": baseline_dense, "chaos": baseline_chaos}, indent=2)
        )

    wandb.init(project=C.WANDB_PROJECT, name=C.WANDB_RUN_NAME)
    wandb.config.update(config)

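    # DeepChaosSFTTrainer: compute_loss steps the chaos scheduler, optionally
    # runs a clean (noise-free) pass to cache reference activations for the
    # stability loss, then combines the cross-entropy task loss with the
    # weighted ortho, stability, preservation-KL, and decorrelation terms.
    # create_optimizer appends the injector parameters as an extra param group.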
    class DeepChaosSFTTrainer(Trainer):
        injector_mgr = None
        base_model_ref = None
        deep_chaos = None
        deep_chaos_stats = None

        def compute_loss(self, target_model, inputs, return_outputs=False, **kwargs):
            if self.deep_chaos is not None:
                self.deep_chaos_stats = self.deep_chaos.step(self.state.global_step)

            labels = inputs.get("labels")
            model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
            assistant_mask = None
            clean_cache = {}

            if self.injector_mgr is not None and labels is not None:
                assistant_mask = (labels != -100).float()
                self.injector_mgr.current_assistant_mask = assistant_mask
                self.injector_mgr.set_noise_scale(C.INJECTOR_NOISE_SCALE)
                if C.STABILITY_LOSS_WEIGHT > 0:
                    self.injector_mgr.set_noise_scale(0.0)
                    with torch.no_grad():
                        target_model(**model_inputs)
                    clean_cache = self.injector_mgr.snapshot_cache(use_post=C.STABILITY_ON_POST_INJECTION, detach=True)
                    self.injector_mgr.clear_cache()
                    self.injector_mgr.current_assistant_mask = assistant_mask
                    self.injector_mgr.set_noise_scale(C.INJECTOR_NOISE_SCALE)

            outputs = target_model(**model_inputs)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss(ignore_index=-100)(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            ortho = torch.tensor(0.0, device=target_model.device)
            stab = torch.tensor(0.0, device=target_model.device)
            pres = torch.tensor(0.0, device=target_model.device)
            decorr = torch.tensor(0.0, device=target_model.device)

            if self.injector_mgr is not None:
                ortho = self.injector_mgr.compute_ortho_loss(use_post=not C.ORTHO_ON_PRE_INJECTION)
                decorr = self.injector_mgr.compute_decorrelation_loss(mask=assistant_mask, use_post=True)
                if clean_cache:
                    perturbed = self.injector_mgr.snapshot_cache(use_post=C.STABILITY_ON_POST_INJECTION, detach=False)
                    stab = self.injector_mgr.compute_stability_loss(clean_cache, perturbed, assistant_mask)

            if self.base_model_ref is not None and C.PRESERVATION_KL_WEIGHT > 0 and labels is not None:
                with torch.no_grad():
                    base_outputs = self.base_model_ref(**model_inputs)
                base_logits = base_outputs.logits[..., :-1, :].contiguous()
                tuned_logits = outputs.logits[..., :-1, :].contiguous()
                policy_mask = (labels[..., 1:].contiguous() == -100).float()
                if policy_mask.sum().item() > 0:
                    log_p_base = F.log_softmax(base_logits, dim=-1)
                    log_p_tuned = F.log_softmax(tuned_logits, dim=-1)
                    kl = (log_p_base.exp() * (log_p_base - log_p_tuned)).sum(-1)
                    pres = (kl * policy_mask).sum() / policy_mask.sum().clamp_min(1)

            main_loss = (
                loss
                + ortho
                + C.STABILITY_LOSS_WEIGHT * stab
                + C.PRESERVATION_KL_WEIGHT * pres
                + C.DECORRELATION_WEIGHT * decorr
            )

            if self.injector_mgr is not None:
                self.injector_mgr.clear_cache()

            if self.state.global_step % C.LOGGING_STEPS == 0 and wandb.run:
                payload = {
                    "loss/main": float(main_loss),
                    "loss/task": float(loss),
                    "ortho/loss": float(ortho),
                    "stability/loss": float(stab),
                    "preservation/kl": float(pres),
                    "decorrelation/loss": float(decorr),
                }
                if self.deep_chaos_stats is not None:
                    payload.update({f"dc/{k}": float(v) for k, v in self.deep_chaos_stats.items()})
                wandb.log(payload, step=self.state.global_step)

            return (main_loss, outputs) if return_outputs else main_loss

        def create_optimizer(self):
            super().create_optimizer()
            if self.injector_mgr is not None:
                injector_params = self.injector_mgr.get_trainable_params()
                if injector_params:
                    self.optimizer.add_param_group({"params": injector_params, "lr": C.LEARNING_RATE, "weight_decay": 0.0})
            return self.optimizer

    training_args = TrainingArguments(
        output_dir=C.OUTPUT_DIR,
        num_train_epochs=C.NUM_EPOCHS,
        per_device_train_batch_size=C.BATCH_SIZE,
        gradient_accumulation_steps=C.GRADIENT_ACCUMULATION,
        learning_rate=C.LEARNING_RATE,
        lr_scheduler_type=C.LR_SCHEDULER,
        warmup_ratio=C.WARMUP_RATIO,
        weight_decay=C.WEIGHT_DECAY,
        max_grad_norm=C.MAX_GRAD_NORM,
        optim="adamw_torch",
        bf16=is_bfloat16_supported(),
        fp16=not is_bfloat16_supported(),
        logging_steps=C.LOGGING_STEPS,
        save_steps=C.SAVE_STEPS,
        save_total_limit=2,
        gradient_checkpointing=True,
        report_to="wandb",
        run_name=C.WANDB_RUN_NAME,
        remove_unused_columns=False,
        seed=42,
    )

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
        label_pad_token_id=-100,
        return_tensors="pt",
    )

    trainer = DeepChaosSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset_formatted,
        processing_class=tokenizer,
        data_collator=data_collator,
        callbacks=[
            ProbeCallback(tokenizer, C.PROBE_STEPS, deep_chaos),
            BoundaryProbeCallback(tokenizer, C.PROBE_STEPS, deep_chaos),
        ],
    )
    trainer.injector_mgr = injector_mgr
    trainer.base_model_ref = base_model_ref
    trainer.deep_chaos = deep_chaos

    sanity_batch = next(iter(trainer.get_train_dataloader()))
    supervised_tokens = (sanity_batch["labels"] != -100).sum().item()
    total_tokens = sanity_batch["labels"].numel()
    if supervised_tokens <= 0 or supervised_tokens >= total_tokens:
        raise RuntimeError(f"Label sanity failed: sup={supervised_tokens} total={total_tokens}")
    print(f"Label sanity: supervised={supervised_tokens} masked={total_tokens - supervised_tokens}")

    if deep_chaos is not None:
        print("\nSmoke test...")
        smoke_stats = deep_chaos.step(0)
        print(f"  Compute est: {smoke_stats['compute_pct']:.1f}%  Layers: {smoke_stats['active_layers']}")
        smoke_inputs = sanity_batch["input_ids"][:1, :64].to(model_device)
        smoke_mask = sanity_batch["attention_mask"][:1, :64].to(model_device)
        model.train()
        with torch.no_grad():
            smoke_outputs = model(input_ids=smoke_inputs, attention_mask=smoke_mask, return_dict=True)
        if not torch.isfinite(smoke_outputs.logits).all():
            raise RuntimeError("Smoke test produced non-finite logits")
        print(f"  Forward pass OK: {smoke_outputs.logits.shape}")

    print(f"\n{'=' * 60}")
    print("READY — DEEP CHAOS v3")
    print(f"{'=' * 60}")
    print(f"  Model: {C.MODEL_NAME}")
    print(f"  Dataset: {len(dataset_formatted)} samples")
    print(f"  Deep Chaos: {C.DEEP_CHAOS_ENABLED}")
    print(f"  Train layers: {freezing_info['train_start']}-{freezing_info['train_end']} of {freezing_info['total_layers'] - 1}")
    print(f"  Export: {C.HF_REPO}")
    print(f"{'=' * 60}")

    trainer_stats = trainer.train()
    print(f"\nDone. Loss: {trainer_stats.metrics['train_loss']:.4f}")

    if injector_mgr is not None:
        injector_mgr.remove_hooks()
    if deep_chaos is not None:
        deep_chaos.remove()

    final_dir = os.path.join(C.OUTPUT_DIR, "final")
    model.save_pretrained(final_dir, safe_serialization=True)
    tokenizer.save_pretrained(final_dir)
    volume.commit()

    api = HfApi()
    api.create_repo(repo_id=C.HF_REPO, private=True, exist_ok=True)
    api.upload_folder(repo_id=C.HF_REPO, repo_type="model", folder_path=final_dir)
    print(f"Hub push complete: {C.HF_REPO}")


@app.local_entrypoint()
def main():
    print("")
    train.remote()
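
To launch, fill in the blanks above, save the file (the name is yours to pick, e.g. lucky_pick_scheduler.py), and run it through the Modal CLI:

modal run lucky_pick_scheduler.py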
